diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index 6d2ec56dc4..deaf99aa39 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -732,6 +732,7 @@ def refit_policy_generation( _refit_buffer_size_gb: The size of the buffer to use for refitting. If it is None, the buffer size will be computed by the remaining memory. This parameter is primarily used for testing. + timer: Optional Timer used to time the prepare/transfer/update phase """ if colocated_inference: policy.offload_before_refit() @@ -748,22 +749,24 @@ def refit_policy_generation( update_success = False if colocated_inference: # get model param keys, which is grouped by size - grouped_param_keys = policy.prepare_weights_for_ipc( - _refit_buffer_size_gb=_refit_buffer_size_gb - ) - total_num_keys = sum(len(k) for k in grouped_param_keys) - print( - f"[Refit] Split {total_num_keys} keys into {len(grouped_param_keys)} groups", - flush=True, - ) - # do update - for keys in grouped_param_keys: - ipc_handles = policy.get_weights_ipc_handles(keys) - update_success = policy_generation.update_weights_from_ipc_handles( - ipc_handles + if _refit_buffer_size_gb is not None: + buffer_size_bytes = _refit_buffer_size_gb * (1024**3) + else: + # Empirically sets ratio as 30% to maximize efficiency. + # The remaining 70% is a necessary buffer reserved for the parameter all-gathering across the expert-parallelism dimension. 
+ memory_ratio = os.getenv("NRL_REFIT_BUFFER_MEMORY_RATIO", "0.3") + buffer_size_bytes = int( + policy.get_free_memory_bytes() * float(memory_ratio) ) - if not update_success: - break + + futures_train = policy.stream_weights_via_ipc_zmq( + buffer_size_bytes=buffer_size_bytes + ) + futures_inference = policy_generation.update_weights_via_ipc_zmq() + # wait for all futures to complete + ray.get(futures_train) + results = ray.get(futures_inference) + update_success = all(result for result in results if result is not None) else: # update weights through nccl futures_train = policy.broadcast_weights_for_collective() diff --git a/nemo_rl/models/generation/interfaces.py b/nemo_rl/models/generation/interfaces.py index d424c7c1df..12e5aecbc1 100644 --- a/nemo_rl/models/generation/interfaces.py +++ b/nemo_rl/models/generation/interfaces.py @@ -233,7 +233,7 @@ def prepare_refit_info(self, state_dict_info: dict[str, Any]) -> None: """Prepare the info for refit.""" raise NotImplementedError - def update_weights_from_ipc_handles(self, ipc_handles: dict[str, Any]) -> bool: + def update_weights_via_ipc_zmq(self) -> list[ray.ObjectRef]: """Update the model weights from the given IPC handles.""" raise NotImplementedError diff --git a/nemo_rl/models/generation/vllm/vllm_backend.py b/nemo_rl/models/generation/vllm/vllm_backend.py index 14db473ae9..0bf2b224d3 100644 --- a/nemo_rl/models/generation/vllm/vllm_backend.py +++ b/nemo_rl/models/generation/vllm/vllm_backend.py @@ -11,12 +11,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from collections import defaultdict -from typing import Any, Optional +import gc +from typing import Any import torch -from torch.multiprocessing.reductions import rebuild_cuda_tensor +import zmq +from nemo_rl.models.policy.utils import ( + IPCProtocol, + calculate_aligned_size, + rebuild_cuda_tensor_from_ipc, +) from nemo_rl.utils.nsys import wrap_with_nvtx_name from nemo_rl.utils.packed_tensor import packed_broadcast_consumer @@ -56,124 +61,107 @@ def init_collective( ) def report_device_id(self) -> str: + """Retrieve the UUID of the current CUDA device.""" from nemo_rl.utils.nvml import get_device_uuid return get_device_uuid(self.device.index) - def prepare_refit_info( - self, state_dict_info: Optional[dict[str, Any]] = None - ) -> None: - """Prepare the info for refit. - - DtensorPolicyWorker: - colocated inference: state_dict_info is None - non-colocated inference: state_dict_info is a dict of {tensor_name: (shape, dtype)} + def get_zmq_address(self): + """Get the ZMQ address for the current device.""" + return f"ipc:///tmp/{self.report_device_id()}.sock" - MegatronPolicyWorker: - colocated inference: state_dict_info is a dict of {tensor_name: (shape, dtype, numel)} - non-colocated inference: state_dict_info is a dict of {tensor_name: (shape, dtype)} - """ - self.state_dict_info = state_dict_info # pyrefly: ignore[implicitly-defined-attribute] This class does not define __init__ so assignments like this should be ignored + def maybe_init_zmq(self): + """Initialize the ZMQ socket if it doesn't exist.""" + if not hasattr(self, "zmq_socket"): + self.zmq_context = zmq.Context() # pyrefly: ignore[implicitly-defined-attribute] This class does not define __init__ so assignments like this should be ignored + self.zmq_socket = self.zmq_context.socket( # pyrefly: ignore[implicitly-defined-attribute] This class does not define __init__ so assignments like this should be ignored + zmq.REP + ) + self.zmq_socket.setsockopt(zmq.SNDTIMEO, 30000) # set timeout to 30 seconds + 
self.zmq_socket.setsockopt(zmq.RCVTIMEO, 30000) # set timeout to 30 seconds + self.zmq_socket.setsockopt(zmq.LINGER, 0) + self.zmq_socket.connect(self.get_zmq_address()) - @wrap_with_nvtx_name( - "vllm_internal_worker_extension/update_weights_from_global_ipc_handles" - ) - def update_weights_from_global_ipc_handles(self, global_device_ipc_handles): - """Update weights from global IPC handles. + def prepare_refit_info(self, state_dict_info: dict[str, Any]) -> None: + """Prepare state dict metadata for weight refitting and IPC streaming. Args: - global_device_ipc_handles (dict): Dictionary mapping device UUIDs to parameter IPC handles. - - Returns: - bool: True if weights were successfully updated. + state_dict_info (dict): A dictionary containing the info for refit. + e.g. {tensor_name: (shape, dtype)} """ - device_uuid = self.report_device_id() - local_device_ipc_handles = global_device_ipc_handles[device_uuid] - return self.update_weights_from_local_ipc_handles(local_device_ipc_handles) - - @wrap_with_nvtx_name( - "vllm_internal_worker_extension/update_weights_from_local_ipc_handles" - ) - def update_weights_from_local_ipc_handles(self, local_device_ipc_handles): - """Update weights from local IPC handles. + self.state_dict_info = state_dict_info # pyrefly: ignore[implicitly-defined-attribute] This class does not define __init__ so assignments like this should be ignored - Args: - local_device_ipc_handles (dict): parameter IPC handles for local device. + @wrap_with_nvtx_name("vllm_internal_worker_extension/update_weights_via_ipc_zmq") + def update_weights_via_ipc_zmq(self) -> bool: + """Receive and update model weights via ZMQ IPC socket. Returns: bool: True if weights were successfully updated. 
""" - try: - is_tensor_packed = local_device_ipc_handles[0] - if is_tensor_packed: - _, all_handles, list_keys = local_device_ipc_handles - else: - _, name_and_handle_list = local_device_ipc_handles + buffer = None + weights = None - device_id = self.device.index - weights = [] + try: + self.maybe_init_zmq() + while True: + # Blocking receive with timeout (this is the main operation) + payload = self.zmq_socket.recv_pyobj() - if is_tensor_packed: - assert self.state_dict_info is not None, ( - "state_dict_info is not prepared. " - "Please call prepare_refit_info when initializing the worker." - ) + if payload == IPCProtocol.COMPLETE: + # means the update is done + self.zmq_socket.send(IPCProtocol.ACK.value.encode()) + break - # Extract packed tensor from IPC handle - dtype_to_packed_tensor = {} - for dtype, tensor_handle in all_handles: - func = rebuild_cuda_tensor - args = tensor_handle[0] - list_args = list(args) - list_args[6] = device_id - tensor = func(*list_args) - dtype_to_packed_tensor[dtype] = tensor + ipc_handle, list_keys, used_bytes = payload + buffer = rebuild_cuda_tensor_from_ipc(ipc_handle, self.device.index) weights = [] - dtype_to_offset = defaultdict(lambda: 0) + offset = 0 for key in list_keys: - shape, dtype, size = self.state_dict_info[key] + shape, dtype = self.state_dict_info[key] # pyrefly + if isinstance(shape, list): + shape = torch.Size(shape) + size_in_bytes = dtype.itemsize * shape.numel() weights.append( ( key, - dtype_to_packed_tensor[dtype][ - dtype_to_offset[dtype] : dtype_to_offset[dtype] + size - ].view(*shape), + buffer[offset : offset + size_in_bytes] + .view(dtype=dtype) + .view(shape), ) ) - dtype_to_offset[dtype] += size - - expected_sizes = { - dtype: tensor.numel() - for dtype, tensor in dtype_to_packed_tensor.items() - } - assert dtype_to_offset == expected_sizes, ( - f"Packed tensor size mismatch: expected sizes from keys list {expected_sizes} != actual packed tensor sizes {dtype_to_offset}. 
" - f"This indicates the keys list order doesn't match the order used when packing tensors." + aligned_size = calculate_aligned_size(size_in_bytes) + offset += aligned_size + assert offset == used_bytes, ( + "Offset is not equal to used bytes, usually indicate inaccurate info like keys or cached dtype in state_dict_info" ) - else: - # Process each handle to get the tensor - for name, handle in name_and_handle_list: - func = rebuild_cuda_tensor - args = handle[0] - list_args = list(args) - list_args[6] = device_id - tensor = func(*list_args) - weights.append((name, tensor)) - - # Load weights into the model - from nemo_rl.models.generation import fp8 - - if fp8.is_fp8_model(self.model_runner.vllm_config): - # the fp8 load_weights additionally casts bf16 weights into fp8 - fp8.load_weights(weights, self.model_runner) - else: - self.model_runner.model.load_weights(weights=weights) - + # Load weights into the model + from nemo_rl.models.generation import fp8 + + if fp8.is_fp8_model(self.model_runner.vllm_config): + # the fp8 load_weights additionally casts bf16 weights into fp8 + fp8.load_weights(weights, self.model_runner) + else: + self.model_runner.model.load_weights(weights=weights) + + torch.cuda.current_stream().synchronize() + + # CRITICAL: Delete views before ACK to prevent corruption. + # 'weights' contains views into IPC shared memory. Even though load_weights() + # copied the data, Python may not garbage collect these view objects immediately. + # If sender reuses the buffer before GC runs, old views would read corrupted data. + # Explicit del ensures immediate cleanup before sending ACK. 
+ del weights, buffer + weights = None + buffer = None + self.zmq_socket.send(IPCProtocol.ACK.value.encode()) + + gc.collect() + torch.cuda.empty_cache() return True except Exception as e: print( - f"Error in VllmInternalWorkerExtension.update_weights_from_ipc_handles: {e}" + f"Error in VllmInternalWorkerExtension.update_weights_via_ipc_zmq: {e}" ) return False @@ -222,6 +210,13 @@ def _load_model_weights(weights, model_runner): return True + def cleanup(self) -> None: + """Shutdown and cleanup resources.""" + # Close ZMQ socket and context if they exist + if hasattr(self, "zmq_socket"): + self.zmq_socket.close() + self.zmq_context.term() + def start_gpu_profiling(self) -> None: """Start GPU profiling.""" torch.cuda.profiler.start() diff --git a/nemo_rl/models/generation/vllm/vllm_generation.py b/nemo_rl/models/generation/vllm/vllm_generation.py index 29a222ba29..f43adf07cf 100644 --- a/nemo_rl/models/generation/vllm/vllm_generation.py +++ b/nemo_rl/models/generation/vllm/vllm_generation.py @@ -763,49 +763,26 @@ def prepare_refit_info(self, state_dict_info: dict[str, Any]) -> None: # Wait for all futures to complete ray.get(futures) - def update_weights_from_ipc_handles(self, ipc_handles: dict[str, Any]) -> bool: - """Update weights of the policy using IPC handles, considering tensor parallelism. - - For tp > 1, only the leader in each tensor parallel tied worker group will update weights. - - Args: - ipc_handles (dict): Dictionary mapping device UUIDs (str) to parameter IPC handles. - - Returns: - bool: True if weights were successfully updated, False otherwise. 
- """ + def update_weights_via_ipc_zmq(self) -> list[ray.ObjectRef]: + """Update weights of the policy using IPC handles via ZMQ socket.""" if not self.worker_group or not self.worker_group.workers: - return False + raise RuntimeError("Worker group is not initialized") # Choose the appropriate method based on async_engine setting method_name = ( - "update_weights_from_ipc_handles_async" + "update_weights_via_ipc_zmq_async" if self.cfg["vllm_cfg"]["async_engine"] - else "update_weights_from_ipc_handles" + else "update_weights_via_ipc_zmq" ) - # Only send the ipc handles required by the current worker - ipc_handles_list = [] - for worker_device_uuids in self.device_uuids: - worker_ipc_handles = { - device_uuid: ipc_handles[device_uuid] - for device_uuid in worker_device_uuids - } - ipc_handles_list.append(worker_ipc_handles) + # Use run_all_workers_single_data since no data needs to be passed + futures = self.worker_group.run_all_workers_single_data( + method_name, + run_rank_0_only_axes=["tensor_parallel", "pipeline_parallel"], + ) - try: - # Directly pass ipc_handles to the method - futures = self.worker_group.run_all_workers_multiple_data( - method_name, - ipc_handles=ipc_handles_list, - run_rank_0_only_axes=["tensor_parallel", "pipeline_parallel"], - ) - # Wait for all futures to complete - results = ray.get(futures) - return all(result for result in results if result is not None) - except Exception as e: - print(f"Error during update weights: {e}") - return False + # this function should co-work with lm_policy, so we should wait for all futures to complete outside + return futures def update_weights_from_collective(self) -> list[ray.ObjectRef]: """Update weights of the policy using collective communication.""" diff --git a/nemo_rl/models/generation/vllm/vllm_worker.py b/nemo_rl/models/generation/vllm/vllm_worker.py index 78d8505632..7ce826a27f 100644 --- a/nemo_rl/models/generation/vllm/vllm_worker.py +++ b/nemo_rl/models/generation/vllm/vllm_worker.py @@ 
-706,16 +706,9 @@ def prepare_refit_info(self, state_dict_info: dict[str, Any]) -> None: """Prepare the info for refit.""" self.llm.collective_rpc("prepare_refit_info", args=(state_dict_info,)) - @wrap_with_nvtx_name("vllm_genertion_worker/update_weights_from_ipc_handles") - def update_weights_from_ipc_handles(self, ipc_handles: dict[str, Any]) -> bool: - """Update weights from IPC handles by delegating to the vLLM Worker implementation. - - Args: - ipc_handles (dict): Dictionary mapping device UUIDs (str) to parameter IPC handles. - - Returns: - bool: True if weights were successfully updated, False otherwise. - """ + @wrap_with_nvtx_name("vllm_genertion_worker/update_weights_via_ipc_zmq") + def update_weights_via_ipc_zmq(self) -> bool: + """Update weights from IPC handles via ZMQ socket.""" try: assert self.llm is not None, ( "Attempting to update weights with either an uninitialized vLLM or non-model-owner" @@ -723,40 +716,13 @@ def update_weights_from_ipc_handles(self, ipc_handles: dict[str, Any]) -> bool: if self.cfg["vllm_cfg"]["async_engine"]: raise RuntimeError( - "update_weights_from_ipc_handles cannot be used with async_engine=True. Use update_weights_from_ipc_handles_async instead." + "update_weights_via_ipc_zmq cannot be used with async_engine=True. Use update_weights_via_ipc_zmq_async instead." ) - if self.tensor_parallel_size == 1: - # UniProcExecutor - assert len(self.vllm_device_ids) == 1 - result_or_coro = self.llm.collective_rpc( - "update_weights_from_local_ipc_handles", - args=(ipc_handles[self.vllm_device_ids[0]],), - ) - else: - """ - DO NOT USE VLLM's collective_rpc: This code causes duplicate IPC data transfer across Ray workers, - leading to unnecessary network serialization overhead and potential performance degradation. 
- - result_or_coro = self.llm.collective_rpc( - "update_weights_from_global_ipc_handles", args=(ipc_handles,) - ) - """ - ray_worker_outputs = [] - # MultiProcExecutor - for worker, device_id in zip( - self.llm.llm_engine.model_executor.workers, self.vllm_device_ids - ): - ray_worker_outputs.append( - worker.execute_method.remote( - "update_weights_from_local_ipc_handles", - ipc_handles[device_id], - ) - ) - - # Gather the results - result_or_coro = ray.get(ray_worker_outputs) - + result_or_coro = self.llm.collective_rpc( + "update_weights_via_ipc_zmq", + args=tuple(), + ) worker_result = result_or_coro[0] if not worker_result: @@ -859,6 +825,9 @@ def shutdown(self) -> bool: """Clean up vLLM resources.""" try: if self.llm is not None: + # Clean up extension resources (e.g., ZMQ sockets) + self.llm.collective_rpc("cleanup", args=tuple()) + # Explicitly delete the engine. This may trigger its __del__ method. del self.llm diff --git a/nemo_rl/models/generation/vllm/vllm_worker_async.py b/nemo_rl/models/generation/vllm/vllm_worker_async.py index b01db28aab..c456c62c03 100644 --- a/nemo_rl/models/generation/vllm/vllm_worker_async.py +++ b/nemo_rl/models/generation/vllm/vllm_worker_async.py @@ -850,17 +850,10 @@ async def prepare_refit_info_async(self, state_dict_info: dict[str, Any]) -> Non """Async version of prepare_refit_info.""" await self.llm.collective_rpc("prepare_refit_info", args=(state_dict_info,)) - async def update_weights_from_ipc_handles_async( - self, ipc_handles: dict[str, Any] + async def update_weights_via_ipc_zmq_async( + self, ) -> bool: - """Async version of update_weights_from_ipc_handles. - - Args: - ipc_handles (dict): Dictionary mapping device UUIDs (str) to parameter IPC handles. - - Returns: - bool: True if weights were successfully updated, False otherwise. 
- """ + """Async version of update_weights_via_ipc_zmq.""" try: assert self.llm is not None, ( "Attempting to update weights with either an uninitialized vLLM or non-model-owner" @@ -868,12 +861,12 @@ async def update_weights_from_ipc_handles_async( if not self.cfg["vllm_cfg"]["async_engine"]: raise RuntimeError( - "update_weights_from_ipc_handles_async can only be used with async_engine=True. Use update_weights_from_ipc_handles instead." + "update_weights_via_ipc_zmq_async can only be used with async_engine=True. Use update_weights_via_ipc_zmq instead." ) # TODO: switch to update_weights_from_local_ipc_handles for better performance once collectively report_device_id is supported in asyncLLM initialization result_or_coro = await self.llm.collective_rpc( - "update_weights_from_global_ipc_handles", args=(ipc_handles,) + "update_weights_via_ipc_zmq", args=tuple() ) if asyncio.iscoroutine(result_or_coro): @@ -984,10 +977,12 @@ async def wake_up_async(self, **kwargs): await self.llm.wake_up(**wake_up_args) - def shutdown(self) -> bool: + async def shutdown(self) -> bool: """Clean up vLLM resources.""" try: if self.llm is not None: + # Clean up extension resources (e.g., ZMQ sockets) + await self.llm.collective_rpc("cleanup", args=tuple()) try: self.llm.shutdown() except Exception as e_stop: diff --git a/nemo_rl/models/policy/dtensor_policy_worker.py b/nemo_rl/models/policy/dtensor_policy_worker.py index c2f2f5a006..34ad4e16d5 100644 --- a/nemo_rl/models/policy/dtensor_policy_worker.py +++ b/nemo_rl/models/policy/dtensor_policy_worker.py @@ -23,6 +23,7 @@ import ray import torch +import zmq from accelerate import init_empty_weights from torch import nn from torch.distributed.checkpoint.state_dict import ( @@ -72,7 +73,6 @@ from nemo_rl.models.policy.utils import ( configure_dynamo_cache, get_gpu_info, - get_handle_from_tensor, get_runtime_env_for_policy_worker, import_class_from_path, resolve_model_class, @@ -159,6 +159,7 @@ def __init__( init_reference_model: bool = 
True, **kwargs: Any, ): + """Initialize the DTensorPolicyWorker.""" self.tokenizer = tokenizer self.processor = processor self.is_vlm = processor is not None @@ -178,15 +179,6 @@ def __init__( # with different order of node_bundles configure_dynamo_cache() - # vars used for refit - ## will be initialized in prepare_refit_info - self.refit_param_info = None - ## used for streaming update inference engine weights - self._held_sharded_state_dict_reference: Optional[dict[str, torch.Tensor]] = ( - None - ) - self._held_streamed_param_reference: Optional[dict[str, torch.Tensor]] = None - self.cfg = config # torch distributed init. Envars for rank, world_size, and master_addr and master_port are set from the ray remote call torch.distributed.init_process_group(backend="nccl") @@ -1704,100 +1696,72 @@ def report_device_id(self) -> str: # Get device UUID using NVML return get_device_uuid(device_idx) - @torch.no_grad() - def prepare_refit_info(self) -> Optional[dict[str, Any]]: - state_dict = self.model.state_dict() + def get_zmq_address(self): + """Get the ZMQ address for the current device.""" + return f"ipc:///tmp/{self.report_device_id()}.sock" - if self.is_generation_colocated: - # Collect info for streaming multiple tensors - self.refit_param_info = [] - for name, tensor in state_dict.items(): - # dtensor's numel will return complete tensor instead of only local tensor - size_in_bytes = tensor.element_size() * tensor.numel() - self.refit_param_info.append((name, size_in_bytes)) - - else: - # Collect info for collective communication - state_dict_info = {} - for name, tensor in state_dict.items(): - state_dict_info[name] = (tensor.shape, self.dtype) - - return state_dict_info + def maybe_init_zmq(self): + """Initialize the ZMQ socket if it doesn't exist.""" + if not hasattr(self, "zmq_socket"): + self.zmq_context = zmq.Context() + self.zmq_socket = self.zmq_context.socket(zmq.REQ) + self.zmq_socket.setsockopt(zmq.SNDTIMEO, 30000) # set timeout to 30 seconds + 
self.zmq_socket.setsockopt(zmq.RCVTIMEO, 30000) # set timeout to 30 seconds + self.zmq_socket.setsockopt(zmq.LINGER, 0) + self.zmq_socket.bind(self.get_zmq_address()) @torch.no_grad() - def prepare_weights_for_ipc(self) -> tuple[list[tuple[str, int]], float]: - """Prepare the weights for IPC. + def prepare_refit_info(self) -> Optional[dict[str, Any]]: + """Prepare state dict metadata for weight refitting and IPC streaming.""" + state_dict_info = {} + for name, tensor in self.model.state_dict().items(): + # all tensor will be casted to self.dtype in stream_weights_via_ipc_zmq/broadcast_weights_for_collective + state_dict_info[name] = (tensor.shape, self.dtype) - This function: - - Prepares the state_dict of the model. - - Collects the info for streaming multiple tensors. + return state_dict_info - Returns: - list: The list of parameters sizes. - float: The total available memory in bytes. - """ + def get_free_memory_bytes(self) -> int: + """Get the available free memory.""" from nemo_rl.utils.nvml import get_free_memory_bytes + device_idx = torch.cuda.current_device() + return get_free_memory_bytes(device_idx) + + @torch.no_grad() + @wrap_with_nvtx_name("dtensor_policy_worker/stream_weights_via_ipc_zmq") + def stream_weights_via_ipc_zmq(self, buffer_size_bytes: int = 0) -> None: + """Stream model weights to peer process via ZMQ IPC socket.""" + self.maybe_init_zmq() # Manually move model to cuda for cpu offload case if self.cpu_offload: self.model = self.move_to_cuda(self.model) - # Get state_dict - self._held_sharded_state_dict_reference: dict[str, torch.Tensor] = ( - self.model.state_dict() - ) - - # Collect current available memory for refit - ## Get current device index from torch - device_idx = torch.cuda.current_device() - ## Get device free memory using NVML - total_available_bytes = get_free_memory_bytes(device_idx) - ## Use 80% of the free memory for safety - memory_ratio = os.getenv("NRL_REFIT_BUFFER_MEMORY_RATIO", "0.8") - total_available_bytes *= 
float(memory_ratio) - - return self.refit_param_info, total_available_bytes - - @torch.no_grad() - @wrap_with_nvtx_name("dtensor_policy_worker/get_weights_ipc_handles") - def get_weights_ipc_handles(self, keys: Iterable[str]) -> dict[str, Any]: - assert self._held_sharded_state_dict_reference is not None, ( - "prepare_weights_for_ipc must be called before get_weights_ipc_handles" + from nemo_rl.models.policy.utils import stream_weights_via_ipc_zmq_impl + + def dtensor_params_generator(): + """Generator that yields (name, tensor) pairs, converting DTensors to local tensors.""" + for name, tensor in self.model.state_dict().items(): + if isinstance(tensor, DTensor): + # Convert DTensor to full tensor for streaming + full_tensor = tensor.full_tensor() + # Convert to target dtype + yield ( + name, + full_tensor.to(self.dtype, non_blocking=True).contiguous(), + ) + else: + # Convert to target dtype + yield name, tensor.to(self.dtype, non_blocking=True).contiguous() + + # Use the shared implementation + stream_weights_via_ipc_zmq_impl( + params_generator=dtensor_params_generator(), + buffer_size_bytes=buffer_size_bytes, + zmq_socket=self.zmq_socket, + rank=self.rank, + worker_name=str(self), ) - # Clean up the held tensors to reduce peak memory - if self._held_streamed_param_reference is not None: - del self._held_streamed_param_reference - self._held_streamed_param_reference = None - - converted_params = {} - for key in keys: - # Get full_tensor for dtensor (GPU > 1) - tensor = self._held_sharded_state_dict_reference[key] - if isinstance(tensor, DTensor): - full_tensor = tensor.full_tensor() - else: - full_tensor = tensor - # Convert parameters to the configured dtype - converted_params[key] = full_tensor.to(self.dtype, non_blocking=True) - - # Temporary record the full tensor for cleanup - # It is needed for cleanup the last full_tensor in the refit process - self._held_streamed_param_reference = converted_params - - # Get device UUID for IPC - device_uuid = 
self.report_device_id() - # Create handles for the tensors - all_handles = [] - for key, p in converted_params.items(): - handle = get_handle_from_tensor(p) - all_handles.append((key, handle)) - - # (pack_tensor_for_ipc: bool, handles: list) - serialized = (False, all_handles) - - return {device_uuid: serialized} - @torch.no_grad() def broadcast_weights_for_collective(self) -> None: """Broadcast the weights for collective communication.""" @@ -1881,23 +1845,12 @@ def offload_before_refit(self) -> None: @torch.no_grad() @wrap_with_nvtx_name("dtensor_policy_worker/offload_after_refit") def offload_after_refit(self) -> None: - # Offload as much as possible on the CPU + """Offload as much as possible on the CPU.""" self.model = self.move_to_cpu(self.model) self.model.eval() torch.randn(1).cuda() # wake up torch allocator self.offload_before_refit() # rerun the old offload function - # Clean up the held tensors - if self._held_sharded_state_dict_reference is not None: - del self._held_sharded_state_dict_reference - self._held_sharded_state_dict_reference = None - if self._held_streamed_param_reference is not None: - del self._held_streamed_param_reference - self._held_streamed_param_reference = None - - gc.collect() - torch.cuda.empty_cache() - # Print memory stats after offloading allocated = torch.cuda.memory_allocated() / (1024**3) # Convert to GB reserved = torch.cuda.memory_reserved() / (1024**3) # Convert to GB @@ -1964,6 +1917,10 @@ def load_checkpoint( def shutdown(self) -> None: """Shutdown the policy.""" + # Clean up extension resources like ZMQ sockets + if hasattr(self, "zmq_socket"): + self.zmq_socket.close() + self.zmq_context.term() def start_gpu_profiling(self) -> None: """Start GPU profiling.""" diff --git a/nemo_rl/models/policy/dtensor_policy_worker_v2.py b/nemo_rl/models/policy/dtensor_policy_worker_v2.py index fbd394bf80..7b60673e7f 100644 --- a/nemo_rl/models/policy/dtensor_policy_worker_v2.py +++ b/nemo_rl/models/policy/dtensor_policy_worker_v2.py 
@@ -18,10 +18,11 @@ import warnings from collections import defaultdict from contextlib import AbstractContextManager, contextmanager, nullcontext -from typing import Any, Generator, Iterable, Optional, cast +from typing import Any, Generator, Optional, cast import ray import torch +import zmq from accelerate import init_empty_weights from nemo_automodel import ( NeMoAutoModelForSequenceClassification, @@ -83,7 +84,6 @@ from nemo_rl.models.policy.utils import ( configure_dynamo_cache, get_gpu_info, - get_handle_from_tensor, get_runtime_env_for_policy_worker, import_class_from_path, resolve_model_class, @@ -122,6 +122,7 @@ def __init__( init_reference_model: bool = True, **kwargs: Any, ): + """Initialize the DTensorPolicyWorkerV2.""" self.tokenizer = tokenizer self.processor = processor self.is_vlm = processor is not None @@ -449,15 +450,6 @@ def __init__( "No weights path provided. Starting from scratch (default policy init)" ) - # vars used for refit - ## will be initialized in prepare_refit_info - self.refit_param_info = None - ## used for streaming update inference engine weights - self._held_sharded_state_dict_reference: Optional[dict[str, torch.Tensor]] = ( - None - ) - self._held_streamed_param_reference: Optional[dict[str, torch.Tensor]] = None - def _apply_temperature_scaling(self, logits: torch.Tensor) -> torch.Tensor: if "generation" in self.cfg and self.cfg["generation"] is not None: logits.div_(self.cfg["generation"]["temperature"]) @@ -1665,100 +1657,72 @@ def report_device_id(self) -> str: # Get device UUID using NVML return get_device_uuid(device_idx) - @torch.no_grad() - def prepare_refit_info(self) -> Optional[dict[str, Any]]: - state_dict = self.model.state_dict() + def get_zmq_address(self): + """Get the ZMQ address for the current device.""" + return f"ipc:///tmp/{self.report_device_id()}.sock" - if self.is_generation_colocated: - # Collect info for streaming multiple tensors - self.refit_param_info = [] - for name, tensor in state_dict.items(): 
- # dtensor's numel will return complete tensor instead of only local tensor - size_in_bytes = tensor.element_size() * tensor.numel() - self.refit_param_info.append((name, size_in_bytes)) - - else: - # Collect info for collective communication - state_dict_info = {} - for name, tensor in state_dict.items(): - state_dict_info[name] = (tensor.shape, self.dtype) - - return state_dict_info + def maybe_init_zmq(self): + """Initialize the ZMQ socket if it doesn't exist.""" + if not hasattr(self, "zmq_socket"): + self.zmq_context = zmq.Context() + self.zmq_socket = self.zmq_context.socket(zmq.REQ) + self.zmq_socket.setsockopt(zmq.SNDTIMEO, 30000) # set timeout to 30 seconds + self.zmq_socket.setsockopt(zmq.RCVTIMEO, 30000) # set timeout to 30 seconds + self.zmq_socket.setsockopt(zmq.LINGER, 0) + self.zmq_socket.bind(self.get_zmq_address()) @torch.no_grad() - def prepare_weights_for_ipc(self) -> tuple[list[tuple[str, int]], float]: - """Prepare the weights for IPC. + def prepare_refit_info(self) -> Optional[dict[str, Any]]: + """Prepare state dict metadata for weight refitting and IPC streaming.""" + state_dict_info = {} + for name, tensor in self.model.state_dict().items(): + # all tensor will be casted to self.dtype in stream_weights_via_ipc_zmq/broadcast_weights_for_collective + state_dict_info[name] = (tensor.shape, self.dtype) - This function: - - Prepares the state_dict of the model. - - Collects the info for streaming multiple tensors. + return state_dict_info - Returns: - list: The list of parameters sizes. - float: The total available memory in bytes. 
- """ + def get_free_memory_bytes(self) -> int: + """Get the available free memory.""" from nemo_rl.utils.nvml import get_free_memory_bytes + device_idx = torch.cuda.current_device() + return get_free_memory_bytes(device_idx) + + @torch.no_grad() + @wrap_with_nvtx_name("dtensor_policy_worker_v2/stream_weights_via_ipc_zmq") + def stream_weights_via_ipc_zmq(self, buffer_size_bytes: int = 0) -> None: + """Stream model weights to peer process via ZMQ IPC socket.""" + self.maybe_init_zmq() # Manually move model to cuda for cpu offload case if self.cpu_offload: self.model = self.move_to_cuda(self.model) - # Get state_dict - self._held_sharded_state_dict_reference: dict[str, torch.Tensor] = ( - self.model.state_dict() - ) - - # Collect current available memory for refit - ## Get current device index from torch - device_idx = torch.cuda.current_device() - ## Get device free memory using NVML - total_available_bytes = get_free_memory_bytes(device_idx) - ## Use 80% of the free memory for safety - memory_ratio = os.getenv("NRL_REFIT_BUFFER_MEMORY_RATIO", "0.8") - total_available_bytes *= float(memory_ratio) - - return self.refit_param_info, total_available_bytes - - @torch.no_grad() - @wrap_with_nvtx_name("dtensor_policy_worker_v2/get_weights_ipc_handles") - def get_weights_ipc_handles(self, keys: Iterable[str]) -> dict[str, Any]: - assert self._held_sharded_state_dict_reference is not None, ( - "prepare_weights_for_ipc must be called before get_weights_ipc_handles" + from nemo_rl.models.policy.utils import stream_weights_via_ipc_zmq_impl + + def dtensor_params_generator(): + """Generator that yields (name, tensor) pairs, converting DTensors to local tensors.""" + for name, tensor in self.model.state_dict().items(): + if isinstance(tensor, DTensor): + # Convert DTensor to full tensor for streaming + full_tensor = tensor.full_tensor() + # Convert to target dtype + yield ( + name, + full_tensor.to(self.dtype, non_blocking=True).contiguous(), + ) + else: + # Convert to target 
dtype + yield name, tensor.to(self.dtype, non_blocking=True).contiguous() + + # Use the shared implementation + stream_weights_via_ipc_zmq_impl( + params_generator=dtensor_params_generator(), + buffer_size_bytes=buffer_size_bytes, + zmq_socket=self.zmq_socket, + rank=self.rank, + worker_name=str(self), ) - # Clean up the held tensors to reduce peak memory - if self._held_streamed_param_reference is not None: - del self._held_streamed_param_reference - self._held_streamed_param_reference = None - - converted_params = {} - for key in keys: - # Get full_tensor for dtensor (GPU > 1) - tensor = self._held_sharded_state_dict_reference[key] - if isinstance(tensor, DTensor): - full_tensor = tensor.full_tensor() - else: - full_tensor = tensor - # Convert parameters to the configured dtype - converted_params[key] = full_tensor.to(self.dtype, non_blocking=True) - - # Temporary record the full tensor for cleanup - # It is needed for cleanup the last full_tensor in the refit process - self._held_streamed_param_reference = converted_params - - # Get device UUID for IPC - device_uuid = self.report_device_id() - # Create handles for the tensors - all_handles = [] - for key, p in converted_params.items(): - handle = get_handle_from_tensor(p) - all_handles.append((key, handle)) - - # (pack_tensor_for_ipc: bool, handles: list) - serialized = (False, all_handles) - - return {device_uuid: serialized} - @torch.no_grad() def broadcast_weights_for_collective(self) -> None: """Broadcast the weights for collective communication.""" @@ -1842,23 +1806,12 @@ def offload_before_refit(self) -> None: @torch.no_grad() @wrap_with_nvtx_name("dtensor_policy_worker_v2/offload_after_refit") def offload_after_refit(self) -> None: - # Offload as much as possible on the CPU + """Offload as much as possible on the CPU.""" self.model = self.move_to_cpu(self.model) self.model.eval() torch.randn(1).cuda() # wake up torch allocator self.offload_before_refit() # rerun the old offload function - # Clean up the 
held tensors - if self._held_sharded_state_dict_reference is not None: - del self._held_sharded_state_dict_reference - self._held_sharded_state_dict_reference = None - if self._held_streamed_param_reference is not None: - del self._held_streamed_param_reference - self._held_streamed_param_reference = None - - gc.collect() - torch.cuda.empty_cache() - # Print memory stats after offloading allocated = torch.cuda.memory_allocated() / (1024**3) # Convert to GB reserved = torch.cuda.memory_reserved() / (1024**3) # Convert to GB @@ -1948,6 +1901,10 @@ def load_checkpoint( def shutdown(self) -> None: """Shutdown the policy.""" + # Clean up extension resources like ZMQ sockets + if hasattr(self, "zmq_socket"): + self.zmq_socket.close() + self.zmq_context.term() def start_gpu_profiling(self) -> None: """Start GPU profiling.""" diff --git a/nemo_rl/models/policy/interfaces.py b/nemo_rl/models/policy/interfaces.py index fef56f0ea2..e221621403 100644 --- a/nemo_rl/models/policy/interfaces.py +++ b/nemo_rl/models/policy/interfaces.py @@ -158,11 +158,9 @@ def prepare_refit_info(self) -> Optional[dict[str, Any]]: pass @abstractmethod - def prepare_weights_for_ipc(self, *args: Any, **kwargs: Any) -> list[list[str]]: - pass - - @abstractmethod - def get_weights_ipc_handles(self, keys: list[str]) -> dict[str, Any]: + def stream_weights_via_ipc_zmq( + self, *args: Any, **kwargs: Any + ) -> list[ray.ObjectRef]: pass @abstractmethod diff --git a/nemo_rl/models/policy/lm_policy.py b/nemo_rl/models/policy/lm_policy.py index 5d08003ad9..0a111c1e31 100644 --- a/nemo_rl/models/policy/lm_policy.py +++ b/nemo_rl/models/policy/lm_policy.py @@ -660,69 +660,19 @@ def prepare_refit_info(self) -> Optional[dict[str, Any]]: # Only get the first worker's info since all workers will have the same result return results[0] - def prepare_weights_for_ipc( - self, _refit_buffer_size_gb: Optional[int] = None - ) -> list[list[str]]: - """Prepare the weights for IPC. 
- - Returns: - list: A list containing the keys of the parameters, which is grouped by size. - """ - # Get the state_dict_info and available memory from all workers + def get_free_memory_bytes(self) -> int: + """Get the available free memory.""" + futures = self.worker_group.run_all_workers_single_data("get_free_memory_bytes") + # minimum free memory from all workers for safety + free_memory_bytes = min(ray.get(future) for future in futures) + return free_memory_bytes + + def stream_weights_via_ipc_zmq(self, buffer_size_bytes: int) -> list[ray.ObjectRef]: + """Send the weights for IPC handles via ZMQ socket.""" futures = self.worker_group.run_all_workers_single_data( - "prepare_weights_for_ipc" - ) - results = ray.get(futures) - - # Only get the first worker's state_dict_info since all workers will have the same result - state_dict_info = results[0][0] - - if _refit_buffer_size_gb is not None: - total_available_bytes = _refit_buffer_size_gb * (1024**3) - else: - # Get the minimum available memory from all workers - total_available_bytes = min(result[1] for result in results) - - # Group tensors by size - cur_available_bytes = total_available_bytes - grouped_param_keys: list[list[str]] = [] - keys: list[str] = [] - - for key, size_in_bytes in state_dict_info: - if size_in_bytes > cur_available_bytes: - if keys: - grouped_param_keys.append(keys) - keys = [] - cur_available_bytes = total_available_bytes - - keys.append(key) - cur_available_bytes -= size_in_bytes - - if keys: - grouped_param_keys.append(keys) - - return grouped_param_keys - - def get_weights_ipc_handles(self, keys: list[str]) -> dict[str, Any]: - """Fetch weight IPC handles from all workers. - - Returns: - dict: A dictionary mapping device UUIDs to parameter IPC handles. 
- """ - # Collect IPC handles from all workers - worker_handles: list[dict[str, Any]] = ray.get( - [ - worker.get_weights_ipc_handles.remote(keys=keys) - for worker in self.worker_group.workers - ] + "stream_weights_via_ipc_zmq", buffer_size_bytes=buffer_size_bytes ) - - # Combine all worker handles into a single dictionary - all_handles = {} - for handle in worker_handles: - all_handles.update(handle) - - return all_handles + return futures def broadcast_weights_for_collective(self) -> list[ray.ObjectRef]: """Broadcast the weights for collective communication.""" @@ -793,7 +743,8 @@ def __del__(self) -> None: the object is lost due to leaving a function scope. It's always recommended that the user calls worker_group.shutdown(). """ - self.worker_group.shutdown() + if hasattr(self, "worker_group"): + self.worker_group.shutdown(cleanup_method="shutdown") def start_gpu_profiling(self) -> None: """Start GPU profiling.""" diff --git a/nemo_rl/models/policy/megatron_policy_worker.py b/nemo_rl/models/policy/megatron_policy_worker.py index 2686e9c53a..1db9c4fa44 100644 --- a/nemo_rl/models/policy/megatron_policy_worker.py +++ b/nemo_rl/models/policy/megatron_policy_worker.py @@ -23,6 +23,7 @@ import ray import torch +import zmq from megatron.bridge import AutoBridge from megatron.bridge.models.model_provider import get_model from megatron.bridge.training import fault_tolerance @@ -125,7 +126,6 @@ from nemo_rl.models.policy.utils import ( configure_dynamo_cache, get_gpu_info, - get_handle_from_tensor, get_megatron_checkpoint_dir, get_runtime_env_for_policy_worker, ) @@ -1852,10 +1852,24 @@ def report_device_id(self) -> str: # Get device UUID using NVML return get_device_uuid(device_idx) + def get_zmq_address(self): + """Get the ZMQ address for the current device.""" + return f"ipc:///tmp/{self.report_device_id()}.sock" + + def maybe_init_zmq(self): + """Initialize the ZMQ socket if it doesn't exist.""" + if not hasattr(self, "zmq_socket"): + self.zmq_context = 
zmq.Context() + self.zmq_socket = self.zmq_context.socket(zmq.REQ) + self.zmq_socket.setsockopt(zmq.SNDTIMEO, 30000) # set timeout to 30 seconds + self.zmq_socket.setsockopt(zmq.RCVTIMEO, 30000) # set timeout to 30 seconds + self.zmq_socket.setsockopt(zmq.LINGER, 0) + self.zmq_socket.bind(self.get_zmq_address()) + @torch.no_grad() @wrap_with_nvtx_name("megatron_policy_worker/prepare_refit_info") def prepare_refit_info(self) -> None: - # Get parameter info for refit / mcore side info + """Prepare state dict metadata for weight refitting and IPC streaming.""" self.refit_param_info_mcore = self._calculate_refit_param_info() # Collect tensor metadata for refit / hf side info @@ -1863,12 +1877,10 @@ def prepare_refit_info(self) -> None: hf_params_generator = self.megatron_bridge.export_hf_weights( [self.model], show_progress=False, + conversion_tasks=self.refit_conversion_tasks, # used for metadata caching ) for name, tensor in hf_params_generator: - if self.is_generation_colocated: - metadata = (tensor.shape, tensor.dtype, tensor.numel()) - else: - metadata = (tensor.shape, tensor.dtype) + metadata = (tensor.shape, tensor.dtype) refit_param_info_hf[name] = metadata return refit_param_info_hf @@ -1923,120 +1935,36 @@ def calculate_size_in_bytes(param, tp_size, ep_size): ) return param_info - @wrap_with_nvtx_name("megatron_policy_worker/prepare_weights_for_ipc") - def prepare_weights_for_ipc(self) -> tuple[list[tuple[str, int]], float]: - """Prepare Megatron model weights for IPC transfer to vLLM. - - Collects information about weight tensors (names and sizes). - Returns a list of (parameter_name, size_in_bytes) tuples. 
- """ + def get_free_memory_bytes(self) -> int: + """Get the available free memory.""" from nemo_rl.utils.nvml import get_free_memory_bytes - # Collect current available memory for refit - ## Get current device index from torch device_idx = torch.cuda.current_device() - ## Get device free memory using NVML - total_available_bytes = get_free_memory_bytes(device_idx) - ## default to 20% to get some more speedup than 10%, OOM if set to 30% - memory_ratio = os.getenv("NRL_REFIT_BUFFER_MEMORY_RATIO", "0.2") - total_available_bytes *= float(memory_ratio) - self.refit_conversion_tasks_current_index = 0 - return self.refit_param_info_mcore, total_available_bytes - - # Temporary fix, 'keys' is a kwarg due to some sort of ray bug + return get_free_memory_bytes(device_idx) + @torch.no_grad() - @wrap_with_nvtx_name("megatron_policy_worker/get_weights_ipc_handles") - def get_weights_ipc_handles(self, *, keys: list[str]) -> dict[str, Any]: - """Get IPC handles for the requested Megatron model weights. 
+ @wrap_with_nvtx_name("megatron_policy_worker/stream_weights_via_ipc_zmq") + def stream_weights_via_ipc_zmq(self, buffer_size_bytes: int = 0) -> None: + """Stream model weights to peer process via ZMQ IPC socket.""" + self.maybe_init_zmq() - Args: - keys: List of parameter names to get handles for - Returns: - Dict mapping device UUID to list of (mapped_key, handle) tuples - """ - if self._held_gather_buffer is not None: - del self._held_gather_buffer - self._held_gather_buffer = None - - # extract the conversion tasks in this pack - conversion_tasks = self.refit_conversion_tasks[ - self.refit_conversion_tasks_current_index : self.refit_conversion_tasks_current_index - + len(keys) - ] - self.refit_conversion_tasks_current_index += len(keys) + from nemo_rl.models.policy.utils import stream_weights_via_ipc_zmq_impl + # Generate HF parameters for streaming hf_params_generator = self.megatron_bridge.export_hf_weights( [self.model], show_progress=False, - conversion_tasks=conversion_tasks, + conversion_tasks=self.refit_conversion_tasks, # used for metadata caching ) - gathered_hf_params = {name: tensor for name, tensor in hf_params_generator} - - # Get device UUID for IPC handles - device_uuid = self.report_device_id() - # Create IPC handles for each parameter - tensor_number_threshold = os.getenv( - "NEMO_RL_MEGATRON_IPC_TENSOR_PACKING_THRESHOLD", "32" - ) # an arbitrary threshold - if len(gathered_hf_params) >= int(tensor_number_threshold): - pack_tensor_for_ipc = True - else: - pack_tensor_for_ipc = False - - if pack_tensor_for_ipc: - # Pack tensors in gathered_hf_params into consolidated tensors by dtype - # First calculate total size needed for each dtype - type_to_total_size = defaultdict(lambda: 0) - - # Record offset of the tensor - for key, tensor in gathered_hf_params.items(): - type_to_total_size[tensor.dtype] += tensor.numel() - - # Allocate consolidated tensors for each dtype - packed_tensors = { - dtype: torch.empty( - total_size, - 
device=next(iter(gathered_hf_params.values())).device, - dtype=dtype, - requires_grad=False, - ) - for dtype, total_size in type_to_total_size.items() - } - - dtype_to_offset = defaultdict(lambda: 0) - # Copy tensors into consolidated buffers - for key, tensor in gathered_hf_params.items(): - dtype = tensor.dtype - size = tensor.numel() - packed_tensors[dtype][ - dtype_to_offset[dtype] : dtype_to_offset[dtype] + size - ].copy_(tensor.detach().view(-1)) - dtype_to_offset[dtype] += size - - # Create IPC handles for consolidated tensors - all_handles = [ - (dtype, get_handle_from_tensor(tensor)) - for dtype, tensor in packed_tensors.items() - ] - - # Store reference to prevent garbage collection - self._held_gather_buffer = packed_tensors - - serialized = ( - pack_tensor_for_ipc, - all_handles, - tuple(gathered_hf_params.keys()), - ) - else: - all_handles = [] - for key, tensor in gathered_hf_params.items(): - handle = get_handle_from_tensor(tensor) - all_handles.append((key, handle)) - self._held_gather_buffer = gathered_hf_params - serialized = (False, all_handles) - - return {device_uuid: serialized} + # Use the shared implementation + stream_weights_via_ipc_zmq_impl( + params_generator=hf_params_generator, + buffer_size_bytes=buffer_size_bytes, + zmq_socket=self.zmq_socket, + rank=self.rank, + worker_name=str(self), + ) @torch.no_grad() def broadcast_weights_for_collective(self) -> None: @@ -2044,6 +1972,7 @@ def broadcast_weights_for_collective(self) -> None: hf_params_generator = self.megatron_bridge.export_hf_weights( [self.model], show_progress=False, + conversion_tasks=self.refit_conversion_tasks, # used for metadata caching ) # param_iterator will return (name, tensor), we only need tensor @@ -2116,9 +2045,8 @@ def offload_before_refit(self): # Move the tensor to CPU and update the state dictionary state[k] = v.to("cpu") - if self.cfg["megatron_cfg"]["empty_unused_memory_level"] >= 1: - gc.collect() - torch.cuda.empty_cache() + gc.collect() + 
torch.cuda.empty_cache() # Print memory stats after offloading allocated = torch.cuda.memory_allocated() / (1024**3) # Convert to GB @@ -2130,22 +2058,14 @@ def offload_before_refit(self): @wrap_with_nvtx_name("megatron_policy_worker/offload_after_refit") def offload_after_refit(self): + """Offload as much as possible on the CPU.""" no_grad = torch.no_grad() no_grad.__enter__() - # Offload as much as possible on the CPU self.model = self.move_model(self.model, "cpu") self.model.eval() torch.randn(1).cuda() # wake up torch allocator self.offload_before_refit() # rerun the old offload function - if self._held_gather_buffer is not None: - del self._held_gather_buffer - self._held_gather_buffer = None - - if self.cfg["megatron_cfg"]["empty_unused_memory_level"] >= 1: - gc.collect() - torch.cuda.empty_cache() - allocated = torch.cuda.memory_allocated() / (1024**3) # Convert to GB reserved = torch.cuda.memory_reserved() / (1024**3) # Convert to GB print( @@ -2296,7 +2216,10 @@ def load_checkpoint(self, weights_path: str, optimizer_path: Optional[str] = Non def shutdown(self): """Shutdown the policy.""" - pass + # Clean up extension resources like ZMQ sockets + if hasattr(self, "zmq_socket"): + self.zmq_socket.close() + self.zmq_context.term() def start_gpu_profiling(self) -> None: """Start GPU profiling.""" diff --git a/nemo_rl/models/policy/utils.py b/nemo_rl/models/policy/utils.py index 52a8d4c054..b5a6c3a086 100644 --- a/nemo_rl/models/policy/utils.py +++ b/nemo_rl/models/policy/utils.py @@ -12,11 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import gc import importlib import os +from enum import Enum from typing import Any, Dict import torch +from torch.multiprocessing.reductions import rebuild_cuda_tensor from transformers import ( AutoConfig, AutoModelForCausalLM, @@ -70,6 +73,13 @@ } +class IPCProtocol(Enum): + """IPC protocol constants for ZMQ weight streaming.""" + + COMPLETE = "complete" + ACK = "ack" + + def resolve_model_class(model_name: str) -> Any: """Resolve the appropriate model class for a given model name.""" if NEMO_AUTOMODEL_AVAILABLE: @@ -250,3 +260,148 @@ def get_handle_from_tensor(tensor: torch.Tensor) -> tuple[Any]: # skip serializing the function for better refit performance return reduce_tensor(tensor.detach())[1:] + + +def calculate_aligned_size(size_bytes: int, alignment: int = 512) -> int: + """Calculate aligned size for memory alignment. + + Args: + size_bytes(int): Size in bytes to align + alignment(int): Alignment boundary in bytes (default 512) + + Returns: + Aligned size in bytes(int). + """ + return int(((size_bytes + alignment - 1) // alignment) * alignment) + + +def stream_weights_via_ipc_zmq_impl( + params_generator, buffer_size_bytes: int, zmq_socket, rank: int, worker_name: str +) -> None: + """Shared implementation for streaming weights via IPC ZMQ with improved memory management. + + Uses ping-pong double buffering to enable overlapping communication while reusing buffers + to reduce memory allocation overhead and improve stability. + + Args: + params_generator: Generator yielding (name, tensor) pairs + buffer_size_bytes: total size of buffer in bytes for batching parameters + zmq_socket: ZMQ socket for communication + rank: Worker rank for logging + worker_name: Name of the worker for logging + """ + # Divide total buffer size by 2 because we use two individual buffers (ping-pong) for overlapping communication. 
+ buffer_size_bytes = buffer_size_bytes // 2 + + def send_buffer_group_overlap(buffer, param_names, used_bytes, await_recv) -> bool: + """Send a group of parameters and return new pending_recv state.""" + # Synchronize before getting IPC handle to ensure data is ready + torch.cuda.current_stream().synchronize() + cuda_ipc_handle = get_handle_from_tensor(buffer) + + if await_recv: + zmq_socket.recv() + + # Payload tuple: (cuda_ipc_handle, param_names, used_bytes) + payload = (cuda_ipc_handle, param_names, used_bytes) + zmq_socket.send_pyobj(payload) + return True # pending_recv = True + + def allocate_buffer(device): + """Allocate a new aligned buffer with proper memory alignment.""" + aligned_size = calculate_aligned_size(buffer_size_bytes) + return torch.empty( + aligned_size, + device=device, + dtype=torch.uint8, + requires_grad=False, + ) + + def pack_tensor(buffer, tensor, used_bytes) -> int: + """Pack tensor into buffer and return new used_bytes.""" + tensor_bytes = tensor.nbytes + buffer[used_bytes : used_bytes + tensor_bytes].data.copy_( + tensor.data.view(-1).view(dtype=torch.uint8), non_blocking=True + ) + return used_bytes + calculate_aligned_size(tensor_bytes) + + # Initialize ping-pong double buffering + buffer_a: torch.Tensor | None = None + buffer_b: torch.Tensor | None = None + current_buffer: torch.Tensor | None = None + + used_bytes = 0 + param_names = [] + await_recv = False + count_of_groups = 0 + + try: + for name, tensor in params_generator: + # Initialize device and buffers on first tensor + if buffer_a is None: + buffer_a = allocate_buffer(tensor.device) + buffer_b = allocate_buffer(tensor.device) + current_buffer = buffer_a + + aligned_size = calculate_aligned_size(tensor.nbytes) + assert aligned_size <= buffer_size_bytes, ( + f"Parameter {name} too large for buffer: {aligned_size} > {buffer_size_bytes}" + ) + + # Check if we need to send current buffer and switch to the other one + if used_bytes + aligned_size > buffer_size_bytes: + 
await_recv = send_buffer_group_overlap( + current_buffer, param_names, used_bytes, await_recv + ) + count_of_groups += 1 + + # Switch buffers for ping-pong double buffering + current_buffer = buffer_b if current_buffer is buffer_a else buffer_a + used_bytes, param_names = 0, [] + + # Pack tensor into current buffer + param_names.append(name) + used_bytes = pack_tensor(current_buffer, tensor, used_bytes) + + # Send remaining tensors + if param_names: + await_recv = send_buffer_group_overlap( + current_buffer, param_names, used_bytes, await_recv + ) + count_of_groups += 1 + + # Complete transmission + if await_recv: + zmq_socket.recv() + + # Final synchronization and completion signal + torch.cuda.current_stream().synchronize() + zmq_socket.send_pyobj(IPCProtocol.COMPLETE) + zmq_socket.recv() + + if rank == 0: + print( + f"{worker_name}: Packed {count_of_groups} groups of tensors", flush=True + ) + + finally: + # Clean up buffers in finally block to ensure cleanup even on exceptions + if buffer_a is not None: + del buffer_a + if buffer_b is not None: + del buffer_b + + # Force garbage collection and clear CUDA cache + gc.collect() + torch.cuda.empty_cache() + + +def rebuild_cuda_tensor_from_ipc( + cuda_ipc_handle: tuple, device_id: int +) -> torch.Tensor: + """Rebuild a CUDA tensor from an IPC handle.""" + func = rebuild_cuda_tensor + args = cuda_ipc_handle[0] + list_args = list(args) + list_args[6] = device_id + return func(*list_args) diff --git a/pyproject.toml b/pyproject.toml index e7692f0bf1..8a0d1fc3d9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,6 +50,7 @@ dependencies = [ "mlflow", "nvidia-nvshmem-cu12", # for deep_ep build "swanlab", + "pyzmq", ] [project.optional-dependencies] diff --git a/tests/unit/models/generation/test_vllm_generation.py b/tests/unit/models/generation/test_vllm_generation.py index 478eac9a7c..84448dca41 100644 --- a/tests/unit/models/generation/test_vllm_generation.py +++ 
b/tests/unit/models/generation/test_vllm_generation.py @@ -1594,11 +1594,11 @@ def test_vllm_weight_update_and_prefix_cache_reset( ) print("Updating vLLM weights from HF policy...") - grouped_param_keys = lm_policy.prepare_weights_for_ipc() - for keys in grouped_param_keys: - ipc_handles = lm_policy.get_weights_ipc_handles(keys) - update_success = vllm_policy.update_weights_from_ipc_handles(ipc_handles) - assert update_success, "Weight update should succeed" + + buffer_size_bytes = int(lm_policy.get_free_memory_bytes() * 0.3) + lm_policy.stream_weights_via_ipc_zmq(buffer_size_bytes=buffer_size_bytes) + update_success = vllm_policy.update_weights_via_ipc_zmq() + assert update_success, "Weight update should succeed" print("vLLM weights successfully updated.") print("Running Generation 2 (Weights Updated, Cache Still Active)...") @@ -1675,7 +1675,7 @@ def test_vllm_weight_update_memory(cluster, tokenizer): lm_policy, vllm_policy, vllm_config["colocated"]["enabled"], - _refit_buffer_size_gb=1, + _refit_buffer_size_gb=1.5, ) gpu_infos = ray.get([w.get_gpu_info.remote() for w in workers]) @@ -2133,7 +2133,7 @@ def test_vllm_megatron_weight_update_memory(cluster, tokenizer): megatron_policy, vllm_policy, vllm_config["colocated"]["enabled"], - _refit_buffer_size_gb=1, + _refit_buffer_size_gb=1.5, ) gpu_infos = ray.get([w.get_gpu_info.remote() for w in workers]) diff --git a/tests/unit/models/policy/test_utils.py b/tests/unit/models/policy/test_utils.py index 8fb4d8f8b2..0b90ab0fbf 100644 --- a/tests/unit/models/policy/test_utils.py +++ b/tests/unit/models/policy/test_utils.py @@ -12,11 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import multiprocessing import os +import sys +import time +import traceback import unittest.mock +import pytest +import torch +import zmq + from nemo_rl.models.policy.utils import ( + IPCProtocol, + calculate_aligned_size, get_megatron_checkpoint_dir, + rebuild_cuda_tensor_from_ipc, + stream_weights_via_ipc_zmq_impl, ) @@ -106,3 +118,228 @@ def test_function_prints_selected_directory(self, capsys): f"Using default megatron checkpoint dir: {expected_dir}" in captured.out ) assert result == expected_dir + + +def server_process( + zmq_addr: str, + known_tensors: list[tuple[str, torch.Tensor]], + buffer_size_bytes: int, + ready_queue: multiprocessing.Queue, +) -> None: + """Server process that streams tensors via IPC ZMQ.""" + try: + device = torch.device("cuda:0") + gpu_tensors = [(name, tensor.to(device)) for name, tensor in known_tensors] + + context = zmq.Context() + socket = context.socket(zmq.PAIR) + socket.setsockopt(zmq.LINGER, 0) # Close immediately on error + socket.setsockopt(zmq.RCVTIMEO, 10000) # 10 second timeout + socket.bind(zmq_addr) + ready_queue.put(("ready", None)) + + stream_weights_via_ipc_zmq_impl( + (t for t in gpu_tensors), + buffer_size_bytes, + socket, + rank=0, + worker_name="test_server", + ) + except Exception as e: + import sys + import traceback + + error_details = f"{type(e).__name__}: {str(e)}\n{traceback.format_exc()}" + ready_queue.put(("error", error_details)) + sys.exit( + 1 + ) # Exit with non-zero code so check_process_error detects the failure + finally: + socket.close() + context.term() + + +def client_process( + zmq_addr: str, + known_tensors_data: list[tuple[str, tuple, torch.dtype, torch.Tensor]], + result_queue: multiprocessing.Queue, +) -> None: + """Client process that receives and validates tensors via IPC ZMQ.""" + try: + device = torch.device("cuda:0") + + # Prepare expected tensors on GPU + expected_tensors = { + name: tensor.to(device) for name, _, _, tensor in known_tensors_data + } + state_dict_info = { + name: 
(shape, dtype) for name, shape, dtype, _ in known_tensors_data + } + + context = zmq.Context() + socket = context.socket(zmq.PAIR) + socket.setsockopt(zmq.LINGER, 0) # Close immediately on error + socket.setsockopt(zmq.RCVTIMEO, 10000) # 10 second timeout + socket.connect(zmq_addr) + + # Receive and validate loop + while True: + payload = socket.recv_pyobj() + if payload == IPCProtocol.COMPLETE: + socket.send(IPCProtocol.ACK.value.encode()) + break + + ipc_handle, list_keys, used_bytes = payload + buffer = rebuild_cuda_tensor_from_ipc(ipc_handle, device.index) + + offset = 0 + for key in list_keys: + shape, dtype = state_dict_info[key] + shape = torch.Size(shape) if isinstance(shape, list) else shape + size_in_bytes = dtype.itemsize * shape.numel() + + tensor = ( + buffer[offset : offset + size_in_bytes] + .view(dtype=dtype) + .view(shape) + ) + expected = expected_tensors[key] + + # Validate tensor + assert tensor.shape == expected.shape, f"Shape mismatch for {key}" + assert tensor.dtype == expected.dtype, f"Dtype mismatch for {key}" + assert torch.allclose(tensor, expected, rtol=1e-7, atol=1e-7), ( + f"Values mismatch for {key}" + ) + + offset += calculate_aligned_size(size_in_bytes) + + assert offset == used_bytes, f"Offset mismatch: {offset} != {used_bytes}" + socket.send(b"") + + result_queue.put(("success", "All tensors validated")) + except Exception as e: + error_details = f"{type(e).__name__}: {str(e)}\n{traceback.format_exc()}" + result_queue.put(("error", error_details)) + sys.exit(1) + finally: + socket.close() + context.term() + + +def check_process_error( + proc: multiprocessing.Process, + queue: multiprocessing.Queue, + process_name: str, +) -> None: + """Check if a process failed and assert with detailed error message if available.""" + if proc.exitcode == 0: + return + + # Get error details from queue + error_msg = None + while not queue.empty(): + status, msg = queue.get_nowait() + if status == "error": + error_msg = msg + break + + if 
def check_process_error(
    proc: multiprocessing.Process,
    queue: multiprocessing.Queue,
    process_name: str,
) -> None:
    """Check if a process failed and raise with a detailed error message if available.

    Args:
        proc: Process whose ``exitcode`` is inspected.
        queue: Queue of ``(status, message)`` tuples; it is drained until the
            first ``"error"`` entry is found.
        process_name: Label used in the failure message.

    Raises:
        AssertionError: If ``proc.exitcode`` is ``None`` (reported as a
            timeout) or any non-zero value.
    """
    if proc.exitcode == 0:
        return

    # Get error details from queue
    error_msg = None
    while not queue.empty():
        status, msg = queue.get_nowait()
        if status == "error":
            error_msg = msg
            break

    details = f"\n{error_msg}" if error_msg else ""

    # Raise explicitly instead of ``assert False`` so the check still fires
    # when Python runs with assertions disabled (-O).
    if proc.exitcode is None:
        raise AssertionError(f"{process_name} timed out{details}")
    raise AssertionError(f"{process_name} failed (exitcode={proc.exitcode}){details}")
class TestStreamWeightsViaIPC:
    """Test suite for IPC weight streaming functionality."""

    TIMEOUT = 30  # 30 second timeout for additional overhead when running with coverage

    @pytest.mark.parametrize(
        "test_case,tensor_specs,buffer_size_bytes,test_description",
        [
            (
                "large_buffer",
                [
                    ("tensor_1", (10, 20), torch.float32),  # 0.78KB
                    ("tensor_2", (5, 15, 25), torch.float32),  # 7.32KB
                    ("tensor_3", (100,), torch.float16),  # 0.20KB
                    ("tensor_4", (50, 50), torch.bfloat16),  # 4.88KB
                    ("tensor_5", (8, 16, 32), torch.float32),  # 16.00KB
                ],  # Total: 29.18KB
                100 * 1024,  # 100 KB - large buffer for single batch (50KB per side)
                "Test with various shapes/dtypes in large buffer (single batch)",
            ),
            (
                "small_buffer",
                [
                    ("small_1", (30, 30), torch.float32),  # 3.52KB
                    ("small_2", (20, 40), torch.float16),  # 1.56KB
                    ("small_3", (128,), torch.float32),  # 0.50KB
                    ("small_4", (25, 35), torch.float32),  # 3.42KB
                ],  # Total: 9.00KB
                10 * 1024,  # 10 KB - forces multiple batches (5KB per side)
                "Test with small buffer forcing multiple batches",
            ),
        ],
    )
    def test_stream_weights_via_ipc_zmq_impl(
        self, test_case, tensor_specs, buffer_size_bytes, test_description
    ):
        """Test streaming weights via IPC ZMQ between server and client processes.

        A server process is spawned first; once it reports ``"ready"`` a client
        process is spawned to receive and validate the tensors. Both processes
        and the socket file are cleaned up even if startup fails.
        """
        # Generate test tensors
        known_tensors = [
            (name, torch.randn(*shape, dtype=dtype))
            for name, shape, dtype in tensor_specs
        ]
        known_tensors_data = [
            (name, list(t.shape), t.dtype, t) for name, t in known_tensors
        ]

        # Create unique socket path and queues
        socket_path = f"/tmp/test_ipc_zmq_{test_case}_{os.getpid()}_{time.time()}"
        zmq_addr = f"ipc://{socket_path}"

        mp_context = multiprocessing.get_context("spawn")
        ready_queue = mp_context.Queue()
        result_queue = mp_context.Queue()

        server_proc = mp_context.Process(
            target=server_process,
            args=(zmq_addr, known_tensors, buffer_size_bytes, ready_queue),
        )
        client_proc = None

        # Startup happens inside the try block so that a failed or timed-out
        # ready handshake still tears down the server and the socket file.
        try:
            server_proc.start()

            status, msg = ready_queue.get(timeout=self.TIMEOUT)
            assert status == "ready", f"Server failed: {msg}"

            client_proc = mp_context.Process(
                target=client_process,
                args=(zmq_addr, known_tensors_data, result_queue),
            )
            client_proc.start()

            # Wait and validate
            server_proc.join(timeout=self.TIMEOUT)
            client_proc.join(timeout=self.TIMEOUT)

            # Check client first since client failure often causes server to fail
            check_process_error(client_proc, result_queue, "Client")
            check_process_error(server_proc, ready_queue, "Server")

            # Verify client success message
            status, msg = result_queue.get(timeout=self.TIMEOUT)
            assert status == "success", f"Validation failed: {msg}"
        finally:
            for proc in [server_proc, client_proc]:
                if proc and proc.is_alive():
                    proc.terminate()
                    proc.join(timeout=self.TIMEOUT)
                    if proc.is_alive():
                        proc.kill()
                        # Reap the killed process so it does not linger as a zombie.
                        proc.join(timeout=self.TIMEOUT)

            if os.path.exists(socket_path):
                os.unlink(socket_path)