From 9d2814022aedba44c1e5f5b91611659f7171fd32 Mon Sep 17 00:00:00 2001 From: Zhiyu Li Date: Fri, 3 Oct 2025 02:15:40 -0700 Subject: [PATCH 01/20] feat: refit refactoring with zmq and overlapping Signed-off-by: Zhiyu Li --- nemo_rl/algorithms/grpo.py | 28 ++- nemo_rl/models/generation/interfaces.py | 2 +- .../models/generation/vllm/vllm_backend.py | 170 +++++++++--------- .../models/generation/vllm/vllm_generation.py | 36 ++-- nemo_rl/models/generation/vllm/vllm_worker.py | 41 +---- .../generation/vllm/vllm_worker_async.py | 6 +- .../models/policy/dtensor_policy_worker.py | 140 +++++---------- nemo_rl/models/policy/interfaces.py | 8 +- nemo_rl/models/policy/lm_policy.py | 72 ++------ .../models/policy/megatron_policy_worker.py | 154 ++++------------ nemo_rl/models/policy/utils.py | 132 ++++++++++++++ pyproject.toml | 1 + uv.lock | 11 ++ 13 files changed, 351 insertions(+), 450 deletions(-) diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index 6d2ec56dc4..cbc3f5ba52 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -748,22 +748,20 @@ def refit_policy_generation( update_success = False if colocated_inference: # get model param keys, which is grouped by size - grouped_param_keys = policy.prepare_weights_for_ipc( - _refit_buffer_size_gb=_refit_buffer_size_gb - ) - total_num_keys = sum(len(k) for k in grouped_param_keys) - print( - f"[Refit] Split {total_num_keys} keys into {len(grouped_param_keys)} groups", - flush=True, + if _refit_buffer_size_gb is not None: + buffer_size_bytes = _refit_buffer_size_gb * (1024**3) + else: + memory_ratio = os.getenv("NRL_REFIT_BUFFER_MEMORY_RATIO", "0.15") + buffer_size_bytes = policy.get_free_memory_bytes() * float(memory_ratio) + + futures_train = policy.stream_weights_via_ipc_zmq( + buffer_size_bytes=buffer_size_bytes ) - # do update - for keys in grouped_param_keys: - ipc_handles = policy.get_weights_ipc_handles(keys) - update_success = policy_generation.update_weights_from_ipc_handles( - ipc_handles - ) - if not update_success: - break + futures_inference = policy_generation.update_weights_via_ipc_zmq() + # wait for all futures to complete + ray.get(futures_train) + results = ray.get(futures_inference) + update_success = all(result for result in results if result is not None) else: # update weights through nccl futures_train = policy.broadcast_weights_for_collective() diff --git a/nemo_rl/models/generation/interfaces.py b/nemo_rl/models/generation/interfaces.py index d424c7c1df..12e5aecbc1 100644 --- a/nemo_rl/models/generation/interfaces.py +++ b/nemo_rl/models/generation/interfaces.py @@ -233,7 +233,7 @@ def prepare_refit_info(self, state_dict_info: dict[str, Any]) -> None: """Prepare the info for refit.""" raise NotImplementedError - def update_weights_from_ipc_handles(self, ipc_handles: dict[str, Any]) -> bool: + def update_weights_via_ipc_zmq(self) -> list[ray.ObjectRef]: """Update the model weights from the given IPC handles.""" raise NotImplementedError diff --git a/nemo_rl/models/generation/vllm/vllm_backend.py b/nemo_rl/models/generation/vllm/vllm_backend.py index 14db473ae9..6abbbb8546 100644 --- a/nemo_rl/models/generation/vllm/vllm_backend.py +++ b/nemo_rl/models/generation/vllm/vllm_backend.py @@ -11,12 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from collections import defaultdict -from typing import Any, Optional +import gc +from typing import Any import torch from torch.multiprocessing.reductions import rebuild_cuda_tensor +import zmq from nemo_rl.utils.nsys import wrap_with_nvtx_name from nemo_rl.utils.packed_tensor import packed_broadcast_consumer @@ -60,117 +61,106 @@ def report_device_id(self) -> str: return get_device_uuid(self.device.index) - def prepare_refit_info( - self, state_dict_info: Optional[dict[str, Any]] = None - ) -> None: + def get_zmq_address(self): + """Get the ZMQ address for the current device.""" + return f"ipc:///{self.report_device_id()}.sock" + + def maybe_init_zmq(self): + """Initialize the ZMQ socket if it doesn't exist.""" + if not hasattr(self, "zmq_socket"): + self.zmq_context = zmq.Context() # pyrefly: ignore[implicitly-defined-attribute] This class does not define __init__ so assignments like this should be ignored + self.zmq_socket = self.zmq_context.socket( # pyrefly: ignore[implicitly-defined-attribute] This class does not define __init__ so assignments like this should be ignored + zmq.REP + ) + # Set receive timeout to 30 seconds to avoid hanging indefinitely + self.zmq_socket.setsockopt( + zmq.RCVTIMEO, 30000 + ) # 30 seconds in milliseconds + self.zmq_socket.connect(self.get_zmq_address()) + + def prepare_refit_info(self, state_dict_info: dict[str, Any]) -> None: """Prepare the info for refit. - DtensorPolicyWorker: - colocated inference: state_dict_info is None - non-colocated inference: state_dict_info is a dict of {tensor_name: (shape, dtype)} - - MegatronPolicyWorker: - colocated inference: state_dict_info is a dict of {tensor_name: (shape, dtype, numel)} - non-colocated inference: state_dict_info is a dict of {tensor_name: (shape, dtype)} + Args: + state_dict_info (dict): A dictionary containing the info for refit. + e.g. {tensor_name: (shape, dtype)} """ self.state_dict_info = state_dict_info # pyrefly: ignore[implicitly-defined-attribute] This class does not define __init__ so assignments like this should be ignored - @wrap_with_nvtx_name( - "vllm_internal_worker_extension/update_weights_from_global_ipc_handles" - ) - def update_weights_from_global_ipc_handles(self, global_device_ipc_handles): - """Update weights from global IPC handles. + @wrap_with_nvtx_name("vllm_internal_worker_extension/update_weights_via_ipc_zmq") + def update_weights_via_ipc_zmq(self) -> bool: + """Update weights from local IPC handles via ZMQ socket. Args: - global_device_ipc_handles (dict): Dictionary mapping device UUIDs to parameter IPC handles. + None Returns: bool: True if weights were successfully updated. """ - device_uuid = self.report_device_id() - local_device_ipc_handles = global_device_ipc_handles[device_uuid] - return self.update_weights_from_local_ipc_handles(local_device_ipc_handles) + buffer = None + weights = None - @wrap_with_nvtx_name( - "vllm_internal_worker_extension/update_weights_from_local_ipc_handles" - ) - def update_weights_from_local_ipc_handles(self, local_device_ipc_handles): - """Update weights from local IPC handles. - - Args: - local_device_ipc_handles (dict): parameter IPC handles for local device. - - Returns: - bool: True if weights were successfully updated. - """ try: - is_tensor_packed = local_device_ipc_handles[0] - if is_tensor_packed: - _, all_handles, list_keys = local_device_ipc_handles - else: - _, name_and_handle_list = local_device_ipc_handles - - device_id = self.device.index - weights = [] - - if is_tensor_packed: - assert self.state_dict_info is not None, ( - "state_dict_info is not prepared. " - "Please call prepare_refit_info when initializing the worker." - ) - - # Extract packed tensor from IPC handle - dtype_to_packed_tensor = {} - for dtype, tensor_handle in all_handles: - func = rebuild_cuda_tensor - args = tensor_handle[0] - list_args = list(args) - list_args[6] = device_id - tensor = func(*list_args) - dtype_to_packed_tensor[dtype] = tensor + self.maybe_init_zmq() + from nemo_rl.models.policy.utils import calculate_aligned_size + + while True: + # Blocking receive with timeout (this is the main operation) + payload = self.zmq_socket.recv_pyobj() + + if payload == "complete": + # means the update is done + self.zmq_socket.send(b"") + break + + packed_tensor_handle, list_keys, used_bytes = payload + device_id = self.device.index + func = rebuild_cuda_tensor + args = packed_tensor_handle[0] + list_args = list(args) + list_args[6] = device_id + buffer = func(*list_args) weights = [] - dtype_to_offset = defaultdict(lambda: 0) + offset = 0 for key in list_keys: - shape, dtype, size = self.state_dict_info[key] + shape, dtype = self.state_dict_info[key] # pyrefly + if isinstance(shape, list): + shape = torch.Size(shape) + size_in_bytes = dtype.itemsize * shape.numel() weights.append( ( key, - dtype_to_packed_tensor[dtype][ - dtype_to_offset[dtype] : dtype_to_offset[dtype] + size - ].view(*shape), + buffer[offset : offset + size_in_bytes] + .view(dtype=dtype) + .view(shape), ) ) - dtype_to_offset[dtype] += size - - expected_sizes = { - dtype: tensor.numel() - for dtype, tensor in dtype_to_packed_tensor.items() - } - assert dtype_to_offset == expected_sizes, ( - f"Packed tensor size mismatch: expected sizes from keys list {expected_sizes} != actual packed tensor sizes {dtype_to_offset}. " - f"This indicates the keys list order doesn't match the order used when packing tensors." + aligned_size = calculate_aligned_size(size_in_bytes) + offset += aligned_size + assert offset == used_bytes, ( + "Offset is not equal to used bytes, usually indicate key info inaccurate like dtype" ) - else: - # Process each handle to get the tensor - for name, handle in name_and_handle_list: - func = rebuild_cuda_tensor - args = handle[0] - list_args = list(args) - list_args[6] = device_id - tensor = func(*list_args) - weights.append((name, tensor)) - - # Load weights into the model - from nemo_rl.models.generation import fp8 - - if fp8.is_fp8_model(self.model_runner.vllm_config): - # the fp8 load_weights additionally casts bf16 weights into fp8 - fp8.load_weights(weights, self.model_runner) - else: - self.model_runner.model.load_weights(weights=weights) - + # Load weights into the model + from nemo_rl.models.generation import fp8 + + if fp8.is_fp8_model(self.model_runner.vllm_config): + # the fp8 load_weights additionally casts bf16 weights into fp8 + fp8.load_weights(weights, self.model_runner) + else: + self.model_runner.model.load_weights(weights=weights) + + torch.cuda.current_stream().synchronize() + self.zmq_socket.send(b"") + + if buffer is not None: + del buffer + if weights is not None: + del weights + gc.collect() + torch.cuda.empty_cache() return True + except Exception as e: print( f"Error in VllmInternalWorkerExtension.update_weights_from_ipc_handles: {e}" diff --git a/nemo_rl/models/generation/vllm/vllm_generation.py b/nemo_rl/models/generation/vllm/vllm_generation.py index 29a222ba29..09ff2f1f6a 100644 --- a/nemo_rl/models/generation/vllm/vllm_generation.py +++ b/nemo_rl/models/generation/vllm/vllm_generation.py @@ -763,7 +763,7 @@ def prepare_refit_info(self, state_dict_info: dict[str, Any]) -> None: # Wait for all futures to complete ray.get(futures) - def update_weights_from_ipc_handles(self, ipc_handles: dict[str, Any]) -> bool: + def update_weights_via_ipc_zmq(self) -> list[ray.ObjectRef]: """Update weights of the policy using IPC handles, considering tensor parallelism. For tp > 1, only the leader in each tensor parallel tied worker group will update weights. @@ -775,37 +775,23 @@ def update_weights_from_ipc_handles(self, ipc_handles: dict[str, Any]) -> bool: bool: True if weights were successfully updated, False otherwise. """ if not self.worker_group or not self.worker_group.workers: - return False + raise RuntimeError("Worker group is not initialized") # Choose the appropriate method based on async_engine setting method_name = ( - "update_weights_from_ipc_handles_async" + "update_weights_via_ipc_zmq_async" if self.cfg["vllm_cfg"]["async_engine"] - else "update_weights_from_ipc_handles" + else "update_weights_via_ipc_zmq" ) - # Only send the ipc handles required by the current worker - ipc_handles_list = [] - for worker_device_uuids in self.device_uuids: - worker_ipc_handles = { - device_uuid: ipc_handles[device_uuid] - for device_uuid in worker_device_uuids - } - ipc_handles_list.append(worker_ipc_handles) + # Use run_all_workers_single_data since no data needs to be passed + futures = self.worker_group.run_all_workers_single_data( + method_name, + run_rank_0_only_axes=["tensor_parallel", "pipeline_parallel"], + ) - try: - # Directly pass ipc_handles to the method - futures = self.worker_group.run_all_workers_multiple_data( - method_name, - ipc_handles=ipc_handles_list, - run_rank_0_only_axes=["tensor_parallel", "pipeline_parallel"], - ) - # Wait for all futures to complete - results = ray.get(futures) - return all(result for result in results if result is not None) - except Exception as e: - print(f"Error during update weights: {e}") - return False + # this function should co-work with lm_policy, so we should wait for all futures to complete outside + return futures def update_weights_from_collective(self) -> list[ray.ObjectRef]: """Update weights of the policy using collective communication.""" diff --git a/nemo_rl/models/generation/vllm/vllm_worker.py b/nemo_rl/models/generation/vllm/vllm_worker.py index 78d8505632..62e680f125 100644 --- a/nemo_rl/models/generation/vllm/vllm_worker.py +++ b/nemo_rl/models/generation/vllm/vllm_worker.py @@ -707,11 +707,11 @@ def prepare_refit_info(self, state_dict_info: dict[str, Any]) -> None: self.llm.collective_rpc("prepare_refit_info", args=(state_dict_info,)) @wrap_with_nvtx_name("vllm_genertion_worker/update_weights_from_ipc_handles") - def update_weights_from_ipc_handles(self, ipc_handles: dict[str, Any]) -> bool: - """Update weights from IPC handles by delegating to the vLLM Worker implementation. + def update_weights_via_ipc_zmq(self) -> bool: + """Update weights from IPC handles via ZMQ socket. Args: - ipc_handles (dict): Dictionary mapping device UUIDs (str) to parameter IPC handles. + None Returns: bool: True if weights were successfully updated, False otherwise. @@ -726,37 +726,10 @@ def update_weights_from_ipc_handles(self, ipc_handles: dict[str, Any]) -> bool: "update_weights_from_ipc_handles cannot be used with async_engine=True. Use update_weights_from_ipc_handles_async instead." ) - if self.tensor_parallel_size == 1: - # UniProcExecutor - assert len(self.vllm_device_ids) == 1 - result_or_coro = self.llm.collective_rpc( - "update_weights_from_local_ipc_handles", - args=(ipc_handles[self.vllm_device_ids[0]],), - ) - else: - """ - DO NOT USE VLLM's collective_rpc: This code causes duplicate IPC data transfer across Ray workers, - leading to unnecessary network serialization overhead and potential performance degradation. - - result_or_coro = self.llm.collective_rpc( - "update_weights_from_global_ipc_handles", args=(ipc_handles,) - ) - """ - ray_worker_outputs = [] - # MultiProcExecutor - for worker, device_id in zip( - self.llm.llm_engine.model_executor.workers, self.vllm_device_ids - ): - ray_worker_outputs.append( - worker.execute_method.remote( - "update_weights_from_local_ipc_handles", - ipc_handles[device_id], - ) - ) - - # Gather the results - result_or_coro = ray.get(ray_worker_outputs) - + result_or_coro = self.llm.collective_rpc( + "update_weights_via_ipc_zmq", + args=tuple(), + ) worker_result = result_or_coro[0] if not worker_result: diff --git a/nemo_rl/models/generation/vllm/vllm_worker_async.py b/nemo_rl/models/generation/vllm/vllm_worker_async.py index b01db28aab..80b5e0a1c2 100644 --- a/nemo_rl/models/generation/vllm/vllm_worker_async.py +++ b/nemo_rl/models/generation/vllm/vllm_worker_async.py @@ -850,8 +850,8 @@ async def prepare_refit_info_async(self, state_dict_info: dict[str, Any]) -> Non """Async version of prepare_refit_info.""" await self.llm.collective_rpc("prepare_refit_info", args=(state_dict_info,)) - async def update_weights_from_ipc_handles_async( - self, ipc_handles: dict[str, Any] + async def update_weights_via_ipc_zmq_async( + self, ) -> bool: """Async version of update_weights_from_ipc_handles. @@ -873,7 +873,7 @@ async def update_weights_from_ipc_handles_async( # TODO: switch to update_weights_from_local_ipc_handles for better performance once collectively report_device_id is supported in asyncLLM initialization result_or_coro = await self.llm.collective_rpc( - "update_weights_from_global_ipc_handles", args=(ipc_handles,) + "update_weights_via_ipc_zmq", args=tuple() ) if asyncio.iscoroutine(result_or_coro): diff --git a/nemo_rl/models/policy/dtensor_policy_worker.py b/nemo_rl/models/policy/dtensor_policy_worker.py index c2f2f5a006..18201f9ef1 100644 --- a/nemo_rl/models/policy/dtensor_policy_worker.py +++ b/nemo_rl/models/policy/dtensor_policy_worker.py @@ -45,6 +45,7 @@ ) from transformers.models.gemma3.modeling_gemma3 import Gemma3ForCausalLM +import zmq from nemo_rl.algorithms.interfaces import LossFunction, LossType from nemo_rl.algorithms.loss_functions import SequencePackingLossWrapper from nemo_rl.distributed.batched_data_dict import BatchedDataDict @@ -72,7 +73,6 @@ from nemo_rl.models.policy.utils import ( configure_dynamo_cache, get_gpu_info, - get_handle_from_tensor, get_runtime_env_for_policy_worker, import_class_from_path, resolve_model_class, @@ -178,9 +178,6 @@ def __init__( # with different order of node_bundles configure_dynamo_cache() - # vars used for refit - ## will be initialized in prepare_refit_info - self.refit_param_info = None ## used for streaming update inference engine weights self._held_sharded_state_dict_reference: Optional[dict[str, torch.Tensor]] = ( None @@ -1704,100 +1701,60 @@ def report_device_id(self) -> str: # Get device UUID using NVML return get_device_uuid(device_idx) - @torch.no_grad() - def prepare_refit_info(self) -> Optional[dict[str, Any]]: - state_dict = self.model.state_dict() + def get_zmq_address(self): + """Get the ZMQ address for the current device.""" + return f"ipc:///{self.report_device_id()}.sock" - if self.is_generation_colocated: - # Collect info for streaming multiple tensors - self.refit_param_info = [] - for name, tensor in state_dict.items(): - # dtensor's numel will return complete tensor instead of only local tensor - size_in_bytes = tensor.element_size() * tensor.numel() - self.refit_param_info.append((name, size_in_bytes)) - - else: - # Collect info for collective communication - state_dict_info = {} - for name, tensor in state_dict.items(): - state_dict_info[name] = (tensor.shape, self.dtype) - - return state_dict_info + def maybe_init_zmq(self): + """Initialize the ZMQ socket if it doesn't exist.""" + if not hasattr(self, "zmq_socket"): + self.zmq_context = zmq.Context() + self.zmq_socket = self.zmq_context.socket(zmq.REQ) + self.zmq_socket.bind(self.get_zmq_address()) @torch.no_grad() - def prepare_weights_for_ipc(self) -> tuple[list[tuple[str, int]], float]: - """Prepare the weights for IPC. + def prepare_refit_info(self) -> Optional[dict[str, Any]]: + state_dict_info = {} + for name, tensor in self.model.state_dict().items(): + state_dict_info[name] = (tensor.shape, self.dtype) - This function: - - Prepares the state_dict of the model. - - Collects the info for streaming multiple tensors. + return state_dict_info - Returns: - list: The list of parameters sizes. - float: The total available memory in bytes. - """ + def get_free_memory_bytes(self) -> int: + """Get the available free memory.""" from nemo_rl.utils.nvml import get_free_memory_bytes - # Manually move model to cuda for cpu offload case - if self.cpu_offload: - self.model = self.move_to_cuda(self.model) - - # Get state_dict - self._held_sharded_state_dict_reference: dict[str, torch.Tensor] = ( - self.model.state_dict() - ) - - # Collect current available memory for refit - ## Get current device index from torch device_idx = torch.cuda.current_device() - ## Get device free memory using NVML - total_available_bytes = get_free_memory_bytes(device_idx) - ## Use 80% of the free memory for safety - memory_ratio = os.getenv("NRL_REFIT_BUFFER_MEMORY_RATIO", "0.8") - total_available_bytes *= float(memory_ratio) - - return self.refit_param_info, total_available_bytes + return get_free_memory_bytes(device_idx) @torch.no_grad() - @wrap_with_nvtx_name("dtensor_policy_worker/get_weights_ipc_handles") - def get_weights_ipc_handles(self, keys: Iterable[str]) -> dict[str, Any]: - assert self._held_sharded_state_dict_reference is not None, ( - "prepare_weights_for_ipc must be called before get_weights_ipc_handles" + @wrap_with_nvtx_name("dtensor_policy_worker/stream_weights_via_ipc_zmq") + def stream_weights_via_ipc_zmq(self, buffer_size_bytes: int = 0) -> None: + self.maybe_init_zmq() + + from nemo_rl.models.policy.utils import stream_weights_via_ipc_zmq_impl + + def dtensor_params_generator(): + """Generator that yields (name, tensor) pairs, converting DTensors to local tensors.""" + for name, tensor in self.model.state_dict().items(): + if isinstance(tensor, DTensor): + # Convert DTensor to full tensor for streaming + full_tensor = tensor.full_tensor() + # Convert to target dtype + yield name, full_tensor.to(self.dtype, non_blocking=True) + else: + # Convert to target dtype + yield name, tensor.to(self.dtype, non_blocking=True) + + # Use the shared implementation + stream_weights_via_ipc_zmq_impl( + params_generator=dtensor_params_generator(), + buffer_size_bytes=buffer_size_bytes, + zmq_socket=self.zmq_socket, + rank=self.rank, + worker_name=str(self), ) - # Clean up the held tensors to reduce peak memory - if self._held_streamed_param_reference is not None: - del self._held_streamed_param_reference - self._held_streamed_param_reference = None - - converted_params = {} - for key in keys: - # Get full_tensor for dtensor (GPU > 1) - tensor = self._held_sharded_state_dict_reference[key] - if isinstance(tensor, DTensor): - full_tensor = tensor.full_tensor() - else: - full_tensor = tensor - # Convert parameters to the configured dtype - converted_params[key] = full_tensor.to(self.dtype, non_blocking=True) - - # Temporary record the full tensor for cleanup - # It is needed for cleanup the last full_tensor in the refit process - self._held_streamed_param_reference = converted_params - - # Get device UUID for IPC - device_uuid = self.report_device_id() - # Create handles for the tensors - all_handles = [] - for key, p in converted_params.items(): - handle = get_handle_from_tensor(p) - all_handles.append((key, handle)) - - # (pack_tensor_for_ipc: bool, handles: list) - serialized = (False, all_handles) - - return {device_uuid: serialized} - @torch.no_grad() def broadcast_weights_for_collective(self) -> None: """Broadcast the weights for collective communication.""" @@ -1887,17 +1844,6 @@ def offload_after_refit(self) -> None: torch.randn(1).cuda() # wake up torch allocator self.offload_before_refit() # rerun the old offload function - # Clean up the held tensors - if self._held_sharded_state_dict_reference is not None: - del self._held_sharded_state_dict_reference - self._held_sharded_state_dict_reference = None - if self._held_streamed_param_reference is not None: - del self._held_streamed_param_reference - self._held_streamed_param_reference = None - - gc.collect() - torch.cuda.empty_cache() - # Print memory stats after offloading allocated = torch.cuda.memory_allocated() / (1024**3) # Convert to GB reserved = torch.cuda.memory_reserved() / (1024**3) # Convert to GB diff --git a/nemo_rl/models/policy/interfaces.py b/nemo_rl/models/policy/interfaces.py index fef56f0ea2..e221621403 100644 --- a/nemo_rl/models/policy/interfaces.py +++ b/nemo_rl/models/policy/interfaces.py @@ -158,11 +158,9 @@ def prepare_refit_info(self) -> Optional[dict[str, Any]]: pass @abstractmethod - def prepare_weights_for_ipc(self, *args: Any, **kwargs: Any) -> list[list[str]]: - pass - - @abstractmethod - def get_weights_ipc_handles(self, keys: list[str]) -> dict[str, Any]: + def stream_weights_via_ipc_zmq( + self, *args: Any, **kwargs: Any + ) -> list[ray.ObjectRef]: pass @abstractmethod diff --git a/nemo_rl/models/policy/lm_policy.py b/nemo_rl/models/policy/lm_policy.py index 5d08003ad9..8d7d6c3b16 100644 --- a/nemo_rl/models/policy/lm_policy.py +++ b/nemo_rl/models/policy/lm_policy.py @@ -660,69 +660,19 @@ def prepare_refit_info(self) -> Optional[dict[str, Any]]: # Only get the first worker's info since all workers will have the same result return results[0] - def prepare_weights_for_ipc( - self, _refit_buffer_size_gb: Optional[int] = None - ) -> list[list[str]]: - """Prepare the weights for IPC. - - Returns: - list: A list containing the keys of the parameters, which is grouped by size. - """ - # Get the state_dict_info and available memory from all workers + def get_free_memory_bytes(self) -> int: + """Get the available free memory.""" + futures = self.worker_group.run_all_workers_single_data("get_free_memory_bytes") + # minimum free memory from all workers for safety + free_memory_bytes = min(ray.get(future) for future in futures) + return free_memory_bytes + + def stream_weights_via_ipc_zmq(self, buffer_size_bytes: int) -> list[ray.ObjectRef]: + """Send the weights for IPC handles via ZMQ socket.""" futures = self.worker_group.run_all_workers_single_data( - "prepare_weights_for_ipc" - ) - results = ray.get(futures) - - # Only get the first worker's state_dict_info since all workers will have the same result - state_dict_info = results[0][0] - - if _refit_buffer_size_gb is not None: - total_available_bytes = _refit_buffer_size_gb * (1024**3) - else: - # Get the minimum available memory from all workers - total_available_bytes = min(result[1] for result in results) - - # Group tensors by size - cur_available_bytes = total_available_bytes - grouped_param_keys: list[list[str]] = [] - keys: list[str] = [] - - for key, size_in_bytes in state_dict_info: - if size_in_bytes > cur_available_bytes: - if keys: - grouped_param_keys.append(keys) - keys = [] - cur_available_bytes = total_available_bytes - - keys.append(key) - cur_available_bytes -= size_in_bytes - - if keys: - grouped_param_keys.append(keys) - - return grouped_param_keys - - def get_weights_ipc_handles(self, keys: list[str]) -> dict[str, Any]: - """Fetch weight IPC handles from all workers. - - Returns: - dict: A dictionary mapping device UUIDs to parameter IPC handles. - """ - # Collect IPC handles from all workers - worker_handles: list[dict[str, Any]] = ray.get( - [ - worker.get_weights_ipc_handles.remote(keys=keys) - for worker in self.worker_group.workers - ] + "stream_weights_via_ipc_zmq", buffer_size_bytes=buffer_size_bytes ) - - # Combine all worker handles into a single dictionary - all_handles = {} - for handle in worker_handles: - all_handles.update(handle) - - return all_handles + return futures def broadcast_weights_for_collective(self) -> list[ray.ObjectRef]: """Broadcast the weights for collective communication.""" diff --git a/nemo_rl/models/policy/megatron_policy_worker.py b/nemo_rl/models/policy/megatron_policy_worker.py index 2686e9c53a..492b9778ed 100644 --- a/nemo_rl/models/policy/megatron_policy_worker.py +++ b/nemo_rl/models/policy/megatron_policy_worker.py @@ -97,6 +97,7 @@ from ray.util.queue import Queue from transformers import PreTrainedTokenizerBase +import zmq from nemo_rl.algorithms.interfaces import LossFunction, LossType from nemo_rl.distributed.batched_data_dict import BatchedDataDict from nemo_rl.distributed.model_utils import ( @@ -125,7 +126,6 @@ from nemo_rl.models.policy.utils import ( configure_dynamo_cache, get_gpu_info, - get_handle_from_tensor, get_megatron_checkpoint_dir, get_runtime_env_for_policy_worker, ) @@ -1852,6 +1852,17 @@ def report_device_id(self) -> str: # Get device UUID using NVML return get_device_uuid(device_idx) + def get_zmq_address(self): + """Get the ZMQ address for the current device.""" + return f"ipc:///{self.report_device_id()}.sock" + + def maybe_init_zmq(self): + """Initialize the ZMQ socket if it doesn't exist.""" + if not hasattr(self, "zmq_socket"): + self.zmq_context = zmq.Context() + self.zmq_socket = self.zmq_context.socket(zmq.REQ) + self.zmq_socket.bind(self.get_zmq_address()) + @torch.no_grad() @wrap_with_nvtx_name("megatron_policy_worker/prepare_refit_info") def prepare_refit_info(self) -> None: @@ -1863,12 +1874,10 @@ def prepare_refit_info(self) -> None: hf_params_generator = self.megatron_bridge.export_hf_weights( [self.model], show_progress=False, + conversion_tasks=self.refit_conversion_tasks, # used for metadata caching ) for name, tensor in hf_params_generator: - if self.is_generation_colocated: - metadata = (tensor.shape, tensor.dtype, tensor.numel()) - else: - metadata = (tensor.shape, tensor.dtype) + metadata = (tensor.shape, tensor.dtype) refit_param_info_hf[name] = metadata return refit_param_info_hf @@ -1923,120 +1932,35 @@ def calculate_size_in_bytes(param, tp_size, ep_size): ) return param_info - @wrap_with_nvtx_name("megatron_policy_worker/prepare_weights_for_ipc") - def prepare_weights_for_ipc(self) -> tuple[list[tuple[str, int]], float]: - """Prepare Megatron model weights for IPC transfer to vLLM. - - Collects information about weight tensors (names and sizes). - Returns a list of (parameter_name, size_in_bytes) tuples. - """ + def get_free_memory_bytes(self) -> int: + """Get the available free memory.""" from nemo_rl.utils.nvml import get_free_memory_bytes - # Collect current available memory for refit - ## Get current device index from torch device_idx = torch.cuda.current_device() - ## Get device free memory using NVML - total_available_bytes = get_free_memory_bytes(device_idx) - ## default to 20% to get some more speedup than 10%, OOM if set to 30% - memory_ratio = os.getenv("NRL_REFIT_BUFFER_MEMORY_RATIO", "0.2") - total_available_bytes *= float(memory_ratio) - self.refit_conversion_tasks_current_index = 0 - return self.refit_param_info_mcore, total_available_bytes - - # Temporary fix, 'keys' is a kwarg due to some sort of ray bug + return get_free_memory_bytes(device_idx) + @torch.no_grad() - @wrap_with_nvtx_name("megatron_policy_worker/get_weights_ipc_handles") - def get_weights_ipc_handles(self, *, keys: list[str]) -> dict[str, Any]: - """Get IPC handles for the requested Megatron model weights. + @wrap_with_nvtx_name("megatron_policy_worker/stream_weights_via_ipc_zmq") + def stream_weights_via_ipc_zmq(self, buffer_size_bytes: int = 0) -> None: + self.maybe_init_zmq() - Args: - keys: List of parameter names to get handles for - Returns: - Dict mapping device UUID to list of (mapped_key, handle) tuples - """ - if self._held_gather_buffer is not None: - del self._held_gather_buffer - self._held_gather_buffer = None - - # extract the conversion tasks in this pack - conversion_tasks = self.refit_conversion_tasks[ - self.refit_conversion_tasks_current_index : self.refit_conversion_tasks_current_index - + len(keys) - ] - self.refit_conversion_tasks_current_index += len(keys) + from nemo_rl.models.policy.utils import stream_weights_via_ipc_zmq_impl + # Generate HF parameters for streaming hf_params_generator = self.megatron_bridge.export_hf_weights( [self.model], show_progress=False, - conversion_tasks=conversion_tasks, + conversion_tasks=self.refit_conversion_tasks, # used for metadata caching ) - gathered_hf_params = {name: tensor for name, tensor in hf_params_generator} - - # Get device UUID for IPC handles - device_uuid = self.report_device_id() - - # Create IPC handles for each parameter - tensor_number_threshold = os.getenv( - "NEMO_RL_MEGATRON_IPC_TENSOR_PACKING_THRESHOLD", "32" - ) # an arbitrary threshold - if len(gathered_hf_params) >= int(tensor_number_threshold): - pack_tensor_for_ipc = True - else: - pack_tensor_for_ipc = False - - if pack_tensor_for_ipc: - # Pack tensors in gathered_hf_params into consolidated tensors by dtype - # First calculate total size needed for each dtype - type_to_total_size = defaultdict(lambda: 0) - - # Record offset of the tensor - for key, tensor in gathered_hf_params.items(): - type_to_total_size[tensor.dtype] += tensor.numel() - - # Allocate consolidated tensors for each dtype - packed_tensors = { - dtype: torch.empty( - total_size, - device=next(iter(gathered_hf_params.values())).device, - dtype=dtype, - requires_grad=False, - ) - for dtype, total_size in type_to_total_size.items() - } - - dtype_to_offset = defaultdict(lambda: 0) - # Copy tensors into consolidated buffers - for key, tensor in gathered_hf_params.items(): - dtype = tensor.dtype - size = tensor.numel() - packed_tensors[dtype][ - dtype_to_offset[dtype] : dtype_to_offset[dtype] + size - ].copy_(tensor.detach().view(-1)) - dtype_to_offset[dtype] += size - - # Create IPC handles for consolidated tensors - all_handles = [ - (dtype, get_handle_from_tensor(tensor)) - for dtype, tensor in packed_tensors.items() - ] - # Store reference to prevent garbage collection - self._held_gather_buffer = packed_tensors - - serialized = ( - pack_tensor_for_ipc, - all_handles, - tuple(gathered_hf_params.keys()), - ) - else: - all_handles = [] - for key, tensor in gathered_hf_params.items(): - handle = get_handle_from_tensor(tensor) - all_handles.append((key, handle)) - self._held_gather_buffer = gathered_hf_params - serialized = (False, all_handles) - - return {device_uuid: serialized} + # Use the shared implementation + stream_weights_via_ipc_zmq_impl( + params_generator=hf_params_generator, + buffer_size_bytes=buffer_size_bytes, + zmq_socket=self.zmq_socket, + rank=self.rank, + worker_name=str(self), + ) @torch.no_grad() def broadcast_weights_for_collective(self) -> None: @@ -2044,6 +1968,7 @@ def broadcast_weights_for_collective(self) -> None: hf_params_generator = self.megatron_bridge.export_hf_weights( [self.model], show_progress=False, + conversion_tasks=self.refit_conversion_tasks, # used for metadata caching ) # param_iterator will return (name, tensor), we only need tensor @@ -2116,9 +2041,8 @@ def offload_before_refit(self): # Move the tensor to CPU and update the state dictionary state[k] = v.to("cpu") - if self.cfg["megatron_cfg"]["empty_unused_memory_level"] >= 1: - gc.collect() - torch.cuda.empty_cache() + gc.collect() + torch.cuda.empty_cache() # Print memory stats after offloading allocated = torch.cuda.memory_allocated() / (1024**3) # Convert to GB @@ -2138,14 +2062,6 @@ def offload_after_refit(self): torch.randn(1).cuda() # wake up torch allocator self.offload_before_refit() # rerun the old offload function - if self._held_gather_buffer is not None: - del self._held_gather_buffer - self._held_gather_buffer = None - - if self.cfg["megatron_cfg"]["empty_unused_memory_level"] >= 1: - gc.collect() - torch.cuda.empty_cache() - allocated = torch.cuda.memory_allocated() / (1024**3) # Convert to GB reserved = torch.cuda.memory_reserved() / (1024**3) # Convert to GB print( diff --git a/nemo_rl/models/policy/utils.py b/nemo_rl/models/policy/utils.py index 52a8d4c054..a2a47fd56f 100644 --- a/nemo_rl/models/policy/utils.py +++ b/nemo_rl/models/policy/utils.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import gc import importlib import os from typing import Any, Dict @@ -250,3 +251,134 @@ def get_handle_from_tensor(tensor: torch.Tensor) -> tuple[Any]: # skip serializing the function for better refit performance return reduce_tensor(tensor.detach())[1:] + + +def calculate_aligned_size(size_bytes: int, alignment: int = 512) -> int: + """Calculate aligned size for memory alignment. + + Args: + size_bytes: Size in bytes to align + alignment: Alignment boundary in bytes (default 512) + + Returns: + Aligned size in bytes + """ + return int(((size_bytes + alignment - 1) // alignment) * alignment) + + +def stream_weights_via_ipc_zmq_impl( + params_generator, buffer_size_bytes: int, zmq_socket, rank: int, worker_name: str +) -> None: + """Shared implementation for streaming weights via IPC ZMQ with improved memory management. + + Uses double buffering to enable overlapping communication while reusing buffers + to reduce memory allocation overhead and improve stability. + + Args: + params_generator: Generator yielding (name, tensor) pairs + buffer_size_bytes: Size of buffer in bytes for batching parameters + zmq_socket: ZMQ socket for communication + rank: Worker rank for logging + worker_name: Name of the worker for logging + """ + + def send_buffer_group_overlap(buffer, param_names, used_bytes, await_recv) -> bool: + """Send a group of parameters and return new pending_recv state.""" + # Synchronize before getting IPC handle to ensure data is ready + torch.cuda.current_stream().synchronize() + cuda_ipc_handle = get_handle_from_tensor(buffer) + + if await_recv: + zmq_socket.recv() + + serialized = (cuda_ipc_handle, tuple(param_names), used_bytes) + zmq_socket.send_pyobj(serialized) + return True # pending_recv = True + + def allocate_buffer(device): + """Allocate a new aligned buffer with proper memory alignment.""" + aligned_size = calculate_aligned_size(buffer_size_bytes) + return torch.empty( + aligned_size, + device=device, + dtype=torch.uint8, + requires_grad=False, + ) + + def pack_tensor(buffer, tensor, used_bytes) -> int: + """Pack tensor into buffer and return new used_bytes.""" + tensor_bytes = tensor.nbytes + buffer[used_bytes : used_bytes + tensor_bytes].data.copy_( + tensor.data.view(-1).view(dtype=torch.uint8), non_blocking=True + ) + return used_bytes + calculate_aligned_size(tensor_bytes) + + # Initialize double buffering system + buffer_a: torch.Tensor | None = None + buffer_b: torch.Tensor | None = None + current_buffer: torch.Tensor | None = None + + used_bytes = 0 + param_names = [] + await_recv = False + count_of_groups = 0 + + try: + for name, tensor in params_generator: + # Initialize device and buffers on first tensor + if buffer_a is None: + buffer_a = allocate_buffer(tensor.device) + buffer_b = allocate_buffer(tensor.device) + current_buffer = buffer_a + + aligned_size = calculate_aligned_size(tensor.nbytes) + assert aligned_size <= buffer_size_bytes, ( + f"Parameter {name} too large for buffer: {aligned_size} > {buffer_size_bytes}" + ) + + # Check if we need to send current buffer and switch to the other one + if used_bytes + aligned_size > buffer_size_bytes: + await_recv = send_buffer_group_overlap( + current_buffer, param_names, used_bytes, await_recv + ) + count_of_groups += 1 + + # Switch buffers for double buffering + current_buffer = buffer_b if current_buffer is buffer_a else buffer_a + used_bytes, param_names = 0, [] + + # Pack tensor into current buffer + param_names.append(name) + used_bytes = pack_tensor(current_buffer, tensor, used_bytes) + + # Send remaining tensors + if param_names: + await_recv = send_buffer_group_overlap( + current_buffer, param_names, used_bytes, await_recv + ) + count_of_groups += 1 + + # Complete transmission + if await_recv: + zmq_socket.recv() + + # Final synchronization and completion signal + torch.cuda.current_stream().synchronize() + zmq_socket.send_pyobj("complete") + zmq_socket.recv() + + if rank == 0: + print( + f"{worker_name}: Packed {count_of_groups} groups of tensors", flush=True + ) + + finally: + # Clean up buffers in finally block to ensure cleanup even on exceptions + if buffer_a is not None: + del buffer_a + if buffer_b is not None: + del buffer_b + + # Force garbage collection and clear CUDA cache + gc.collect() + torch.cuda.empty_cache() diff --git a/pyproject.toml b/pyproject.toml index e7692f0bf1..be6e366eec 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,6 +50,7 @@ dependencies = [ "mlflow", "nvidia-nvshmem-cu12", # for deep_ep build "swanlab", + "zmq", ] [project.optional-dependencies] diff --git a/uv.lock b/uv.lock index 2019ca2b41..9dc8232056 100644 --- a/uv.lock +++ b/uv.lock @@ -3151,6 +3151,7 @@ dependencies = [ { name = "triton", version = "3.3.1", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "sys_platform != 'darwin'" }, { name = "triton", version = "3.4.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform == 'darwin'" }, { name = "wandb" }, + { name = "zmq" }, ] [package.optional-dependencies] @@ -3272,6 +3273,7 @@ requires-dist = [ { name = "vllm", marker = "extra == 'mcore'", specifier = "==0.10.0" }, { name = "vllm", marker = "extra == 'vllm'", specifier = "==0.10.0" }, { name = "wandb" }, + { name = "zmq" }, ] provides-extras = ["automodel", "vllm", "mcore", "penguin"] @@ -7095,3 +7097,12 @@ sdist = { url = "https://files.pythonhosted.org/packages/e3/02/0f2892c661036d50e wheels = [ { url = "https://files.pythonhosted.org/packages/2e/54/647ade08bf0db230bfea292f893923872fd20be6ac6f53b2b936ba839d75/zipp-3.23.0-py3-none-any.whl", hash = "sha256:071652d6115ed432f5ce1d34c336c0adfd6a884660d1e9712a256d3d3bd4b14e", size = 10276, upload-time = "2025-06-08T17:06:38.034Z" }, ] + +[[package]] +name = "zmq" +version = "0.0.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pyzmq" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/6e/78/833b2808793c1619835edb1a4e17a023d5d625f4f97ff25ffff986d1f472/zmq-0.0.0.tar.gz", hash = "sha256:6b1a1de53338646e8c8405803cffb659e8eb7bb02fff4c9be62a7acfac8370c9", size = 966, upload-time = "2015-05-21T17:34:26.603Z" } From 21f3df69eb00b1d6b15bc855f95452cc59a2f4fd Mon Sep 17 00:00:00 2001 From: Zhiyu Li Date: Fri, 3 Oct 2025 10:10:34 -0700 Subject: [PATCH 02/20] address issue Signed-off-by: Zhiyu Li --- nemo_rl/algorithms/grpo.py | 4 +- .../models/generation/vllm/vllm_backend.py | 6 +- .../models/generation/vllm/vllm_generation.py | 11 +- nemo_rl/models/generation/vllm/vllm_worker.py | 4 +- .../generation/vllm/vllm_worker_async.py | 4 +- .../models/policy/dtensor_policy_worker.py | 21 ++- .../models/policy/dtensor_policy_worker_v2.py | 154 ++++++------------ .../models/policy/megatron_policy_worker.py | 2 +- 8 files changed, 74 insertions(+), 132 deletions(-) diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index cbc3f5ba52..b35406c42c 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -752,7 +752,9 @@ def refit_policy_generation( buffer_size_bytes = _refit_buffer_size_gb * (1024**3) else: memory_ratio = os.getenv("NRL_REFIT_BUFFER_MEMORY_RATIO", "0.15") - buffer_size_bytes = policy.get_free_memory_bytes() * float(memory_ratio) + buffer_size_bytes = int( + policy.get_free_memory_bytes() * float(memory_ratio) + ) futures_train = policy.stream_weights_via_ipc_zmq( buffer_size_bytes=buffer_size_bytes diff --git a/nemo_rl/models/generation/vllm/vllm_backend.py b/nemo_rl/models/generation/vllm/vllm_backend.py index 6abbbb8546..a78b55fd49 100644 --- a/nemo_rl/models/generation/vllm/vllm_backend.py +++ b/nemo_rl/models/generation/vllm/vllm_backend.py @@ -63,7 +63,7 @@ def report_device_id(self) -> str: def get_zmq_address(self): """Get the ZMQ address for the current device.""" - return f"ipc:///{self.report_device_id()}.sock" + return f"ipc:///tmp/{self.report_device_id()}.sock" def maybe_init_zmq(self): """Initialize the ZMQ socket if it doesn't exist.""" @@ -71,7 +71,7 @@ def maybe_init_zmq(self): self.zmq_context = zmq.Context() # pyrefly: ignore[implicitly-defined-attribute] This class does not define __init__ so assignments like this should be ignored self.zmq_socket = self.zmq_context.socket( # pyrefly: ignore[implicitly-defined-attribute] This class does not define __init__ so assignments like this should be ignored zmq.REP - ) + ) # Set receive timeout to 30 seconds to avoid hanging indefinitely self.zmq_socket.setsockopt( zmq.RCVTIMEO, 30000 @@ -139,7 +139,7 @@ def update_weights_via_ipc_zmq(self) -> bool: aligned_size = calculate_aligned_size(size_in_bytes) offset += aligned_size assert offset == used_bytes, ( - "Offset is not equal to used bytes, usually indicate key info inaccurate like dtype" + "Offset is not equal to used bytes, usually indicate inaccurate info like keys or cached dtype in state_dict_info" ) # Load weights into the model from nemo_rl.models.generation import fp8 diff --git a/nemo_rl/models/generation/vllm/vllm_generation.py b/nemo_rl/models/generation/vllm/vllm_generation.py index 09ff2f1f6a..f43adf07cf 100644 --- a/nemo_rl/models/generation/vllm/vllm_generation.py +++ b/nemo_rl/models/generation/vllm/vllm_generation.py @@ -764,16 +764,7 @@ def prepare_refit_info(self, state_dict_info: dict[str, Any]) -> None: ray.get(futures) def update_weights_via_ipc_zmq(self) -> list[ray.ObjectRef]: - """Update weights of the policy using IPC handles, considering tensor parallelism. - - For tp > 1, only the leader in each tensor parallel tied worker group will update weights. - - Args: - ipc_handles (dict): Dictionary mapping device UUIDs (str) to parameter IPC handles. - - Returns: - bool: True if weights were successfully updated, False otherwise. - """ + """Update weights of the policy using IPC handles via ZMQ socket.""" if not self.worker_group or not self.worker_group.workers: raise RuntimeError("Worker group is not initialized") diff --git a/nemo_rl/models/generation/vllm/vllm_worker.py b/nemo_rl/models/generation/vllm/vllm_worker.py index 62e680f125..85e8cf8b49 100644 --- a/nemo_rl/models/generation/vllm/vllm_worker.py +++ b/nemo_rl/models/generation/vllm/vllm_worker.py @@ -706,7 +706,7 @@ def prepare_refit_info(self, state_dict_info: dict[str, Any]) -> None: """Prepare the info for refit.""" self.llm.collective_rpc("prepare_refit_info", args=(state_dict_info,)) - @wrap_with_nvtx_name("vllm_genertion_worker/update_weights_from_ipc_handles") + @wrap_with_nvtx_name("vllm_genertion_worker/update_weights_via_ipc_zmq") def update_weights_via_ipc_zmq(self) -> bool: """Update weights from IPC handles via ZMQ socket. @@ -723,7 +723,7 @@ def update_weights_via_ipc_zmq(self) -> bool: if self.cfg["vllm_cfg"]["async_engine"]: raise RuntimeError( - "update_weights_from_ipc_handles cannot be used with async_engine=True. Use update_weights_from_ipc_handles_async instead." + "update_weights_via_ipc_zmq cannot be used with async_engine=True. Use update_weights_via_ipc_zmq_async instead." ) result_or_coro = self.llm.collective_rpc( diff --git a/nemo_rl/models/generation/vllm/vllm_worker_async.py b/nemo_rl/models/generation/vllm/vllm_worker_async.py index 80b5e0a1c2..bc28f45b86 100644 --- a/nemo_rl/models/generation/vllm/vllm_worker_async.py +++ b/nemo_rl/models/generation/vllm/vllm_worker_async.py @@ -856,7 +856,7 @@ async def update_weights_via_ipc_zmq_async( """Async version of update_weights_from_ipc_handles. Args: - ipc_handles (dict): Dictionary mapping device UUIDs (str) to parameter IPC handles. + None Returns: bool: True if weights were successfully updated, False otherwise. @@ -868,7 +868,7 @@ async def update_weights_via_ipc_zmq_async( if not self.cfg["vllm_cfg"]["async_engine"]: raise RuntimeError( - "update_weights_from_ipc_handles_async can only be used with async_engine=True. Use update_weights_from_ipc_handles instead." + "update_weights_via_ipc_zmq_async can only be used with async_engine=True. Use update_weights_via_ipc_zmq instead." ) # TODO: switch to update_weights_from_local_ipc_handles for better performance once collectively report_device_id is supported in asyncLLM initialization diff --git a/nemo_rl/models/policy/dtensor_policy_worker.py b/nemo_rl/models/policy/dtensor_policy_worker.py index 18201f9ef1..e931ee89b6 100644 --- a/nemo_rl/models/policy/dtensor_policy_worker.py +++ b/nemo_rl/models/policy/dtensor_policy_worker.py @@ -178,12 +178,6 @@ def __init__( # with different order of node_bundles configure_dynamo_cache() - ## used for streaming update inference engine weights - self._held_sharded_state_dict_reference: Optional[dict[str, torch.Tensor]] = ( - None - ) - self._held_streamed_param_reference: Optional[dict[str, torch.Tensor]] = None - self.cfg = config # torch distributed init. Envars for rank, world_size, and master_addr and master_port are set from the ray remote call torch.distributed.init_process_group(backend="nccl") @@ -1703,19 +1697,25 @@ def report_device_id(self) -> str: def get_zmq_address(self): """Get the ZMQ address for the current device.""" - return f"ipc:///{self.report_device_id()}.sock" + return f"ipc:///tmp/{self.report_device_id()}.sock" def maybe_init_zmq(self): """Initialize the ZMQ socket if it doesn't exist.""" if not hasattr(self, "zmq_socket"): self.zmq_context = zmq.Context() self.zmq_socket = self.zmq_context.socket(zmq.REQ) + self.zmq_socket.setsockopt(zmq.SNDTIMEO, 30000) # 30s + self.zmq_socket.setsockopt(zmq.RCVTIMEO, 30000) # 30s + self.zmq_socket.setsockopt(zmq.LINGER, 0) self.zmq_socket.bind(self.get_zmq_address()) @torch.no_grad() def prepare_refit_info(self) -> Optional[dict[str, Any]]: state_dict_info = {} for name, tensor in self.model.state_dict().items(): + assert tensor.dtype == self.dtype, ( + f"Tensor {name} has dtype {tensor.dtype} but expected {self.dtype}" + ) state_dict_info[name] = (tensor.shape, self.dtype) return state_dict_info @@ -1741,10 +1741,13 @@ def dtensor_params_generator(): # Convert DTensor to full tensor for streaming full_tensor = tensor.full_tensor() # Convert to target dtype - yield name, full_tensor.to(self.dtype, non_blocking=True) + yield ( + name, + full_tensor.to(self.dtype, non_blocking=True).contiguous(), + ) else: # Convert to target dtype - yield name, tensor.to(self.dtype, non_blocking=True) + yield name, tensor.to(self.dtype, non_blocking=True).contiguous() # Use the shared implementation stream_weights_via_ipc_zmq_impl( diff --git a/nemo_rl/models/policy/dtensor_policy_worker_v2.py b/nemo_rl/models/policy/dtensor_policy_worker_v2.py index fbd394bf80..8965329ed2 100644 --- a/nemo_rl/models/policy/dtensor_policy_worker_v2.py +++ b/nemo_rl/models/policy/dtensor_policy_worker_v2.py @@ -18,7 +18,7 @@ import warnings from collections import defaultdict from contextlib import AbstractContextManager, contextmanager, nullcontext -from typing import Any, Generator, Iterable, Optional, cast +from typing import Any, Generator, Optional, cast import ray import torch @@ -62,6 +62,7 @@ ) from transformers.models.gemma3.modeling_gemma3 import Gemma3ForCausalLM +import zmq from nemo_rl.algorithms.interfaces import LossFunction, LossType from nemo_rl.algorithms.loss_functions import SequencePackingLossWrapper from nemo_rl.distributed.batched_data_dict import BatchedDataDict @@ -83,7 +84,6 @@ from nemo_rl.models.policy.utils import ( configure_dynamo_cache, get_gpu_info, - get_handle_from_tensor, get_runtime_env_for_policy_worker, import_class_from_path, resolve_model_class, @@ -449,15 +449,6 @@ def __init__( "No weights path provided. Starting from scratch (default policy init)" ) - # vars used for refit - ## will be initialized in prepare_refit_info - self.refit_param_info = None - ## used for streaming update inference engine weights - self._held_sharded_state_dict_reference: Optional[dict[str, torch.Tensor]] = ( - None - ) - self._held_streamed_param_reference: Optional[dict[str, torch.Tensor]] = None - def _apply_temperature_scaling(self, logits: torch.Tensor) -> torch.Tensor: if "generation" in self.cfg and self.cfg["generation"] is not None: logits.div_(self.cfg["generation"]["temperature"]) @@ -1665,100 +1656,66 @@ def report_device_id(self) -> str: # Get device UUID using NVML return get_device_uuid(device_idx) - @torch.no_grad() - def prepare_refit_info(self) -> Optional[dict[str, Any]]: - state_dict = self.model.state_dict() - - if self.is_generation_colocated: - # Collect info for streaming multiple tensors - self.refit_param_info = [] - for name, tensor in state_dict.items(): - # dtensor's numel will return complete tensor instead of only local tensor - size_in_bytes = tensor.element_size() * tensor.numel() - self.refit_param_info.append((name, size_in_bytes)) - - else: - # Collect info for collective communication - state_dict_info = {} - for name, tensor in state_dict.items(): - state_dict_info[name] = (tensor.shape, self.dtype) + def get_zmq_address(self): + """Get the ZMQ address for the current device.""" + return f"ipc:///tmp/{self.report_device_id()}.sock" - return state_dict_info + def maybe_init_zmq(self): + """Initialize the ZMQ socket if it doesn't exist.""" + if not hasattr(self, "zmq_socket"): + self.zmq_context = zmq.Context() + self.zmq_socket = self.zmq_context.socket(zmq.REQ) + self.zmq_socket.setsockopt(zmq.SNDTIMEO, 30000) # 30s + self.zmq_socket.setsockopt(zmq.RCVTIMEO, 30000) # 30s + self.zmq_socket.setsockopt(zmq.LINGER, 0) + self.zmq_socket.bind(self.get_zmq_address()) @torch.no_grad() - def prepare_weights_for_ipc(self) -> tuple[list[tuple[str, int]], float]: - """Prepare the weights for IPC. + def prepare_refit_info(self) -> Optional[dict[str, Any]]: + state_dict_info = {} + for name, tensor in self.model.state_dict().items(): + assert tensor.dtype == self.dtype, ( + f"Tensor {name} has dtype {tensor.dtype} but expected {self.dtype}" + ) + state_dict_info[name] = (tensor.shape, self.dtype) - This function: - - Prepares the state_dict of the model. - - Collects the info for streaming multiple tensors. + return state_dict_info - Returns: - list: The list of parameters sizes. - float: The total available memory in bytes. - """ + def get_free_memory_bytes(self) -> int: + """Get the available free memory.""" from nemo_rl.utils.nvml import get_free_memory_bytes - # Manually move model to cuda for cpu offload case - if self.cpu_offload: - self.model = self.move_to_cuda(self.model) - - # Get state_dict - self._held_sharded_state_dict_reference: dict[str, torch.Tensor] = ( - self.model.state_dict() - ) - - # Collect current available memory for refit - ## Get current device index from torch device_idx = torch.cuda.current_device() - ## Get device free memory using NVML - total_available_bytes = get_free_memory_bytes(device_idx) - ## Use 80% of the free memory for safety - memory_ratio = os.getenv("NRL_REFIT_BUFFER_MEMORY_RATIO", "0.8") - total_available_bytes *= float(memory_ratio) - - return self.refit_param_info, total_available_bytes + return get_free_memory_bytes(device_idx) @torch.no_grad() - @wrap_with_nvtx_name("dtensor_policy_worker_v2/get_weights_ipc_handles") - def get_weights_ipc_handles(self, keys: Iterable[str]) -> dict[str, Any]: - assert self._held_sharded_state_dict_reference is not None, ( - "prepare_weights_for_ipc must be called before get_weights_ipc_handles" + @wrap_with_nvtx_name("dtensor_policy_worker_v2/stream_weights_via_ipc_zmq") + def stream_weights_via_ipc_zmq(self, buffer_size_bytes: int = 0) -> None: + self.maybe_init_zmq() + + from nemo_rl.models.policy.utils import stream_weights_via_ipc_zmq_impl + + def dtensor_params_generator(): + """Generator that yields (name, tensor) pairs, converting DTensors to local tensors.""" + for name, tensor in self.model.state_dict().items(): + if isinstance(tensor, DTensor): + # Convert DTensor to full tensor for streaming + full_tensor = tensor.full_tensor() + # Convert to target dtype + yield name, full_tensor.to(self.dtype, non_blocking=True) + else: + # Convert to target dtype + yield name, tensor.to(self.dtype, non_blocking=True) + + # Use the shared implementation + stream_weights_via_ipc_zmq_impl( + params_generator=dtensor_params_generator(), + buffer_size_bytes=buffer_size_bytes, + zmq_socket=self.zmq_socket, + rank=self.rank, + worker_name=str(self), ) - # Clean up the held tensors to reduce peak memory - if self._held_streamed_param_reference is not None: - del self._held_streamed_param_reference - self._held_streamed_param_reference = None - - converted_params = {} - for key in keys: - # Get full_tensor for dtensor (GPU > 1) - tensor = self._held_sharded_state_dict_reference[key] - if isinstance(tensor, DTensor): - full_tensor = tensor.full_tensor() - else: - full_tensor = tensor - # Convert parameters to the configured dtype - converted_params[key] = full_tensor.to(self.dtype, non_blocking=True) - - # Temporary record the full tensor for cleanup - # It is needed for cleanup the last full_tensor in the refit process - self._held_streamed_param_reference = converted_params - - # Get device UUID for IPC - device_uuid = self.report_device_id() - # Create handles for the tensors - all_handles = [] - for key, p in converted_params.items(): - handle = get_handle_from_tensor(p) - all_handles.append((key, handle)) - - # (pack_tensor_for_ipc: bool, handles: list) - serialized = (False, all_handles) - - return {device_uuid: serialized} - @torch.no_grad() def broadcast_weights_for_collective(self) -> None: """Broadcast the weights for collective communication.""" @@ -1848,17 +1805,6 @@ def offload_after_refit(self) -> None: torch.randn(1).cuda() # wake up torch allocator self.offload_before_refit() # rerun the old offload function - # Clean up the held tensors - if self._held_sharded_state_dict_reference is not None: - del self._held_sharded_state_dict_reference - self._held_sharded_state_dict_reference = None - if self._held_streamed_param_reference is not None: - del self._held_streamed_param_reference - self._held_streamed_param_reference = None - - gc.collect() - torch.cuda.empty_cache() - # Print memory stats after offloading allocated = torch.cuda.memory_allocated() / (1024**3) # Convert to GB reserved = torch.cuda.memory_reserved() / (1024**3) # Convert to GB diff --git a/nemo_rl/models/policy/megatron_policy_worker.py b/nemo_rl/models/policy/megatron_policy_worker.py index 492b9778ed..7fdb9b4dda 100644 --- a/nemo_rl/models/policy/megatron_policy_worker.py +++ b/nemo_rl/models/policy/megatron_policy_worker.py @@ -1854,7 +1854,7 @@ def report_device_id(self) -> str: def get_zmq_address(self): """Get the ZMQ address for the current device.""" - return f"ipc:///{self.report_device_id()}.sock" + return f"ipc:///tmp/{self.report_device_id()}.sock" def maybe_init_zmq(self): """Initialize the ZMQ socket if it doesn't exist.""" From 7fdf6f698bbcbea02007a35a4a889469e74097cb Mon Sep 17 00:00:00 2001 From: Zhiyu Li Date: Fri, 3 Oct 2025 11:29:26 -0700 Subject: [PATCH 03/20] address comments and lint Signed-off-by: Zhiyu Li --- nemo_rl/models/generation/vllm/vllm_backend.py | 2 +- nemo_rl/models/policy/dtensor_policy_worker.py | 6 ++---- nemo_rl/models/policy/dtensor_policy_worker_v2.py | 10 ++++------ nemo_rl/models/policy/megatron_policy_worker.py | 2 +- pyproject.toml | 2 +- uv.lock | 13 ++----------- 6 files changed, 11 insertions(+), 24 deletions(-) diff --git a/nemo_rl/models/generation/vllm/vllm_backend.py b/nemo_rl/models/generation/vllm/vllm_backend.py index a78b55fd49..7c59276e73 100644 --- a/nemo_rl/models/generation/vllm/vllm_backend.py +++ b/nemo_rl/models/generation/vllm/vllm_backend.py @@ -15,9 +15,9 @@ from typing import Any import torch +import zmq from torch.multiprocessing.reductions import rebuild_cuda_tensor -import zmq from nemo_rl.utils.nsys import wrap_with_nvtx_name from nemo_rl.utils.packed_tensor import packed_broadcast_consumer diff --git a/nemo_rl/models/policy/dtensor_policy_worker.py b/nemo_rl/models/policy/dtensor_policy_worker.py index e931ee89b6..2fec0ca819 100644 --- a/nemo_rl/models/policy/dtensor_policy_worker.py +++ b/nemo_rl/models/policy/dtensor_policy_worker.py @@ -23,6 +23,7 @@ import ray import torch +import zmq from accelerate import init_empty_weights from torch import nn from torch.distributed.checkpoint.state_dict import ( @@ -45,7 +46,6 @@ ) from transformers.models.gemma3.modeling_gemma3 import Gemma3ForCausalLM -import zmq from nemo_rl.algorithms.interfaces import LossFunction, LossType from nemo_rl.algorithms.loss_functions import SequencePackingLossWrapper from nemo_rl.distributed.batched_data_dict import BatchedDataDict @@ -1713,9 +1713,7 @@ def maybe_init_zmq(self): def prepare_refit_info(self) -> Optional[dict[str, Any]]: state_dict_info = {} for name, tensor in self.model.state_dict().items(): - assert tensor.dtype == self.dtype, ( - f"Tensor {name} has dtype {tensor.dtype} but expected {self.dtype}" - ) + # all tensor will be casted to self.dtype in stream_weights_via_ipc_zmq/broadcast_weights_for_collective state_dict_info[name] = (tensor.shape, self.dtype) return state_dict_info diff --git a/nemo_rl/models/policy/dtensor_policy_worker_v2.py b/nemo_rl/models/policy/dtensor_policy_worker_v2.py index 8965329ed2..a765ad243a 100644 --- a/nemo_rl/models/policy/dtensor_policy_worker_v2.py +++ b/nemo_rl/models/policy/dtensor_policy_worker_v2.py @@ -22,6 +22,7 @@ import ray import torch +import zmq from accelerate import init_empty_weights from nemo_automodel import ( NeMoAutoModelForSequenceClassification, @@ -62,7 +63,6 @@ ) from transformers.models.gemma3.modeling_gemma3 import Gemma3ForCausalLM -import zmq from nemo_rl.algorithms.interfaces import LossFunction, LossType from nemo_rl.algorithms.loss_functions import SequencePackingLossWrapper from nemo_rl.distributed.batched_data_dict import BatchedDataDict @@ -1674,9 +1674,7 @@ def maybe_init_zmq(self): def prepare_refit_info(self) -> Optional[dict[str, Any]]: state_dict_info = {} for name, tensor in self.model.state_dict().items(): - assert tensor.dtype == self.dtype, ( - f"Tensor {name} has dtype {tensor.dtype} but expected {self.dtype}" - ) + # all tensor will be casted to self.dtype in stream_weights_via_ipc_zmq/broadcast_weights_for_collective state_dict_info[name] = (tensor.shape, self.dtype) return state_dict_info @@ -1702,10 +1700,10 @@ def dtensor_params_generator(): # Convert DTensor to full tensor for streaming full_tensor = tensor.full_tensor() # Convert to target dtype - yield name, full_tensor.to(self.dtype, non_blocking=True) + yield name, full_tensor.to(self.dtype, non_blocking=True).contiguous() else: # Convert to target dtype - yield name, tensor.to(self.dtype, non_blocking=True) + yield name, tensor.to(self.dtype, non_blocking=True).contiguous() # Use the shared implementation stream_weights_via_ipc_zmq_impl( diff --git a/nemo_rl/models/policy/megatron_policy_worker.py b/nemo_rl/models/policy/megatron_policy_worker.py index 7fdb9b4dda..1ae912ae2c 100644 --- a/nemo_rl/models/policy/megatron_policy_worker.py +++ b/nemo_rl/models/policy/megatron_policy_worker.py @@ -23,6 +23,7 @@ import ray import torch +import zmq from megatron.bridge import AutoBridge from megatron.bridge.models.model_provider import get_model from megatron.bridge.training import fault_tolerance @@ -97,7 +98,6 @@ from ray.util.queue import Queue from transformers import PreTrainedTokenizerBase -import zmq from nemo_rl.algorithms.interfaces import LossFunction, LossType from nemo_rl.distributed.batched_data_dict import BatchedDataDict from nemo_rl.distributed.model_utils import ( diff --git a/pyproject.toml b/pyproject.toml index be6e366eec..8a0d1fc3d9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,7 +50,7 @@ dependencies = [ "mlflow", "nvidia-nvshmem-cu12", # for deep_ep build "swanlab", - "zmq", + "pyzmq", ] [project.optional-dependencies] diff --git a/uv.lock b/uv.lock index 9dc8232056..4b34110cb2 100644 --- a/uv.lock +++ b/uv.lock @@ -3134,6 +3134,7 @@ dependencies = [ { name = "omegaconf" }, { name = "pillow" }, { name = "plotly" }, + { name = "pyzmq" }, { name = "ray", extra = ["default"] }, { name = "rich" }, { name = "setuptools" }, @@ -3151,7 +3152,6 @@ dependencies = [ { name = "triton", version = "3.3.1", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "sys_platform != 'darwin'" }, { name = "triton", version = "3.4.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform == 'darwin'" }, { name = "wandb" }, - { name = "zmq" }, ] [package.optional-dependencies] @@ -3253,6 +3253,7 @@ requires-dist = [ { name = "penguin", marker = "extra == 'penguin'", editable = "3rdparty/Penguin-workspace" }, { name = "pillow", specifier = ">=11.3.0" }, { name = "plotly" }, + { name = "pyzmq" }, { name = "ray", extras = ["default"], specifier = "==2.46.0" }, { name = "rich" }, { name = "setuptools" }, @@ -3273,7 +3274,6 @@ requires-dist = [ { name = "vllm", marker = "extra == 'mcore'", specifier = "==0.10.0" }, { name = "vllm", marker = "extra == 'vllm'", specifier = "==0.10.0" }, { name = "wandb" }, - { name = "zmq" }, ] provides-extras = ["automodel", "vllm", "mcore", "penguin"] @@ -7097,12 +7097,3 @@ sdist = { url = "https://files.pythonhosted.org/packages/e3/02/0f2892c661036d50e wheels = [ { url = "https://files.pythonhosted.org/packages/2e/54/647ade08bf0db230bfea292f893923872fd20be6ac6f53b2b936ba839d75/zipp-3.23.0-py3-none-any.whl", hash = "sha256:071652d6115ed432f5ce1d34c336c0adfd6a884660d1e9712a256d3d3bd4b14e", size = 10276, upload-time = "2025-06-08T17:06:38.034Z" }, ] - -[[package]] -name = "zmq" -version = "0.0.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "pyzmq" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/6e/78/833b2808793c1619835edb1a4e17a023d5d625f4f97ff25ffff986d1f472/zmq-0.0.0.tar.gz", hash = "sha256:6b1a1de53338646e8c8405803cffb659e8eb7bb02fff4c9be62a7acfac8370c9", size = 966, upload-time = "2015-05-21T17:34:26.603Z" } From cd759224a956d8fe372b467c47de9fdfed896c57 Mon Sep 17 00:00:00 2001 From: Zhiyu Li Date: Tue, 7 Oct 2025 19:06:18 -0700 Subject: [PATCH 04/20] fix Signed-off-by: Zhiyu Li --- nemo_rl/models/generation/vllm/vllm_backend.py | 13 +++++++++---- nemo_rl/models/policy/dtensor_policy_worker_v2.py | 5 ++++- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/nemo_rl/models/generation/vllm/vllm_backend.py b/nemo_rl/models/generation/vllm/vllm_backend.py index 7c59276e73..9368c1868e 100644 --- a/nemo_rl/models/generation/vllm/vllm_backend.py +++ b/nemo_rl/models/generation/vllm/vllm_backend.py @@ -151,12 +151,17 @@ def update_weights_via_ipc_zmq(self) -> bool: self.model_runner.model.load_weights(weights=weights) torch.cuda.current_stream().synchronize() + + # CRITICAL: Delete views before ACK to prevent corruption. + # 'weights' contains views into IPC shared memory. Even though load_weights() + # copied the data, Python may not garbage collect these view objects immediately. + # If sender reuses the buffer before GC runs, old views would read corrupted data. + # Explicit del ensures immediate cleanup before sending ACK. + del weights, buffer + weights = None + buffer = None self.zmq_socket.send(b"") - if buffer is not None: - del buffer - if weights is not None: - del weights gc.collect() torch.cuda.empty_cache() return True diff --git a/nemo_rl/models/policy/dtensor_policy_worker_v2.py b/nemo_rl/models/policy/dtensor_policy_worker_v2.py index a765ad243a..55801122be 100644 --- a/nemo_rl/models/policy/dtensor_policy_worker_v2.py +++ b/nemo_rl/models/policy/dtensor_policy_worker_v2.py @@ -1700,7 +1700,10 @@ def dtensor_params_generator(): # Convert DTensor to full tensor for streaming full_tensor = tensor.full_tensor() # Convert to target dtype - yield name, full_tensor.to(self.dtype, non_blocking=True).contiguous() + yield ( + name, + full_tensor.to(self.dtype, non_blocking=True).contiguous(), + ) else: # Convert to target dtype yield name, tensor.to(self.dtype, non_blocking=True).contiguous() From ebe08f463107db73ef46ac4905399eda6c898e05 Mon Sep 17 00:00:00 2001 From: Zhiyu Li Date: Tue, 7 Oct 2025 22:10:29 -0700 Subject: [PATCH 05/20] docstring Signed-off-by: Zhiyu Li --- nemo_rl/algorithms/grpo.py | 1 + nemo_rl/models/generation/vllm/vllm_backend.py | 17 +++++++---------- nemo_rl/models/generation/vllm/vllm_worker.py | 9 +-------- .../models/generation/vllm/vllm_worker_async.py | 9 +-------- nemo_rl/models/policy/dtensor_policy_worker.py | 5 ++++- .../models/policy/dtensor_policy_worker_v2.py | 5 ++++- nemo_rl/models/policy/megatron_policy_worker.py | 8 ++++++-- nemo_rl/models/policy/utils.py | 6 +++--- 8 files changed, 27 insertions(+), 33 deletions(-) diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index b35406c42c..8cbf838ab8 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -732,6 +732,7 @@ def refit_policy_generation( _refit_buffer_size_gb: The size of the buffer to use for refitting. If it is None, the buffer size will be computed by the remaining memory. This parameter is primarily used for testing. + timer: Optional Timer used to time the prepare/transfer/update phase """ if colocated_inference: policy.offload_before_refit() diff --git a/nemo_rl/models/generation/vllm/vllm_backend.py b/nemo_rl/models/generation/vllm/vllm_backend.py index 9368c1868e..d17fdd27ad 100644 --- a/nemo_rl/models/generation/vllm/vllm_backend.py +++ b/nemo_rl/models/generation/vllm/vllm_backend.py @@ -57,6 +57,7 @@ def init_collective( ) def report_device_id(self) -> str: + """Retrieve the UUID of the current CUDA device.""" from nemo_rl.utils.nvml import get_device_uuid return get_device_uuid(self.device.index) @@ -72,14 +73,13 @@ def maybe_init_zmq(self): self.zmq_socket = self.zmq_context.socket( # pyrefly: ignore[implicitly-defined-attribute] This class does not define __init__ so assignments like this should be ignored zmq.REP ) - # Set receive timeout to 30 seconds to avoid hanging indefinitely - self.zmq_socket.setsockopt( - zmq.RCVTIMEO, 30000 - ) # 30 seconds in milliseconds + self.zmq_socket.setsockopt(zmq.SNDTIMEO, 30000) # 30s + self.zmq_socket.setsockopt(zmq.RCVTIMEO, 30000) # 30s + self.zmq_socket.setsockopt(zmq.LINGER, 0) self.zmq_socket.connect(self.get_zmq_address()) def prepare_refit_info(self, state_dict_info: dict[str, Any]) -> None: - """Prepare the info for refit. + """Prepare state dict metadata for weight refitting and IPC streaming. Args: state_dict_info (dict): A dictionary containing the info for refit. @@ -89,10 +89,7 @@ def prepare_refit_info(self, state_dict_info: dict[str, Any]) -> None: @wrap_with_nvtx_name("vllm_internal_worker_extension/update_weights_via_ipc_zmq") def update_weights_via_ipc_zmq(self) -> bool: - """Update weights from local IPC handles via ZMQ socket. - - Args: - None + """Receive and update model weights via ZMQ IPC socket. Returns: bool: True if weights were successfully updated. @@ -168,7 +165,7 @@ def update_weights_via_ipc_zmq(self) -> bool: except Exception as e: print( - f"Error in VllmInternalWorkerExtension.update_weights_from_ipc_handles: {e}" + f"Error in VllmInternalWorkerExtension.update_weights_via_ipc_zmq: {e}" ) return False diff --git a/nemo_rl/models/generation/vllm/vllm_worker.py b/nemo_rl/models/generation/vllm/vllm_worker.py index 85e8cf8b49..3c3d316c32 100644 --- a/nemo_rl/models/generation/vllm/vllm_worker.py +++ b/nemo_rl/models/generation/vllm/vllm_worker.py @@ -708,14 +708,7 @@ def prepare_refit_info(self, state_dict_info: dict[str, Any]) -> None: @wrap_with_nvtx_name("vllm_genertion_worker/update_weights_via_ipc_zmq") def update_weights_via_ipc_zmq(self) -> bool: - """Update weights from IPC handles via ZMQ socket. - - Args: - None - - Returns: - bool: True if weights were successfully updated, False otherwise. - """ + """Update weights from IPC handles via ZMQ socket.""" try: assert self.llm is not None, ( "Attempting to update weights with either an uninitialized vLLM or non-model-owner" diff --git a/nemo_rl/models/generation/vllm/vllm_worker_async.py b/nemo_rl/models/generation/vllm/vllm_worker_async.py index bc28f45b86..d0b379b126 100644 --- a/nemo_rl/models/generation/vllm/vllm_worker_async.py +++ b/nemo_rl/models/generation/vllm/vllm_worker_async.py @@ -853,14 +853,7 @@ async def prepare_refit_info_async(self, state_dict_info: dict[str, Any]) -> Non async def update_weights_via_ipc_zmq_async( self, ) -> bool: - """Async version of update_weights_from_ipc_handles. - - Args: - None - - Returns: - bool: True if weights were successfully updated, False otherwise. - """ + """Async version of update_weights_via_ipc_zmq.""" try: assert self.llm is not None, ( "Attempting to update weights with either an uninitialized vLLM or non-model-owner" diff --git a/nemo_rl/models/policy/dtensor_policy_worker.py b/nemo_rl/models/policy/dtensor_policy_worker.py index 2fec0ca819..c9f17fe559 100644 --- a/nemo_rl/models/policy/dtensor_policy_worker.py +++ b/nemo_rl/models/policy/dtensor_policy_worker.py @@ -159,6 +159,7 @@ def __init__( init_reference_model: bool = True, **kwargs: Any, ): + """Initialize the DTensorPolicyWorker.""" self.tokenizer = tokenizer self.processor = processor self.is_vlm = processor is not None @@ -1711,6 +1712,7 @@ def maybe_init_zmq(self): @torch.no_grad() def prepare_refit_info(self) -> Optional[dict[str, Any]]: + """Prepare state dict metadata for weight refitting and IPC streaming.""" state_dict_info = {} for name, tensor in self.model.state_dict().items(): # all tensor will be casted to self.dtype in stream_weights_via_ipc_zmq/broadcast_weights_for_collective @@ -1728,6 +1730,7 @@ def get_free_memory_bytes(self) -> int: @torch.no_grad() @wrap_with_nvtx_name("dtensor_policy_worker/stream_weights_via_ipc_zmq") def stream_weights_via_ipc_zmq(self, buffer_size_bytes: int = 0) -> None: + """Stream model weights to peer process via ZMQ IPC socket.""" self.maybe_init_zmq() from nemo_rl.models.policy.utils import stream_weights_via_ipc_zmq_impl @@ -1839,7 +1842,7 @@ def offload_before_refit(self) -> None: @torch.no_grad() @wrap_with_nvtx_name("dtensor_policy_worker/offload_after_refit") def offload_after_refit(self) -> None: - # Offload as much as possible on the CPU + """Offload as much as possible on the CPU.""" self.model = self.move_to_cpu(self.model) self.model.eval() torch.randn(1).cuda() # wake up torch allocator diff --git a/nemo_rl/models/policy/dtensor_policy_worker_v2.py b/nemo_rl/models/policy/dtensor_policy_worker_v2.py index 55801122be..224c4873cf 100644 --- a/nemo_rl/models/policy/dtensor_policy_worker_v2.py +++ b/nemo_rl/models/policy/dtensor_policy_worker_v2.py @@ -122,6 +122,7 @@ def __init__( init_reference_model: bool = True, **kwargs: Any, ): + """Initialize the DTensorPolicyWorkerV2.""" self.tokenizer = tokenizer self.processor = processor self.is_vlm = processor is not None @@ -1672,6 +1673,7 @@ def maybe_init_zmq(self): @torch.no_grad() def prepare_refit_info(self) -> Optional[dict[str, Any]]: + """Prepare state dict metadata for weight refitting and IPC streaming.""" state_dict_info = {} for name, tensor in self.model.state_dict().items(): # all tensor will be casted to self.dtype in stream_weights_via_ipc_zmq/broadcast_weights_for_collective @@ -1689,6 +1691,7 @@ def get_free_memory_bytes(self) -> int: @torch.no_grad() @wrap_with_nvtx_name("dtensor_policy_worker_v2/stream_weights_via_ipc_zmq") def stream_weights_via_ipc_zmq(self, buffer_size_bytes: int = 0) -> None: + """Stream model weights to peer process via ZMQ IPC socket.""" self.maybe_init_zmq() from nemo_rl.models.policy.utils import stream_weights_via_ipc_zmq_impl @@ -1800,7 +1803,7 @@ def offload_before_refit(self) -> None: @torch.no_grad() @wrap_with_nvtx_name("dtensor_policy_worker_v2/offload_after_refit") def offload_after_refit(self) -> None: - # Offload as much as possible on the CPU + """Offload as much as possible on the CPU.""" self.model = self.move_to_cpu(self.model) self.model.eval() torch.randn(1).cuda() # wake up torch allocator diff --git a/nemo_rl/models/policy/megatron_policy_worker.py b/nemo_rl/models/policy/megatron_policy_worker.py index 1ae912ae2c..5006802f03 100644 --- a/nemo_rl/models/policy/megatron_policy_worker.py +++ b/nemo_rl/models/policy/megatron_policy_worker.py @@ -1861,12 +1861,15 @@ def maybe_init_zmq(self): if not hasattr(self, "zmq_socket"): self.zmq_context = zmq.Context() self.zmq_socket = self.zmq_context.socket(zmq.REQ) + self.zmq_socket.setsockopt(zmq.SNDTIMEO, 30000) # 30s + self.zmq_socket.setsockopt(zmq.RCVTIMEO, 30000) # 30s + self.zmq_socket.setsockopt(zmq.LINGER, 0) self.zmq_socket.bind(self.get_zmq_address()) @torch.no_grad() @wrap_with_nvtx_name("megatron_policy_worker/prepare_refit_info") def prepare_refit_info(self) -> None: - # Get parameter info for refit / mcore side info + """Prepare state dict metadata for weight refitting and IPC streaming.""" self.refit_param_info_mcore = self._calculate_refit_param_info() # Collect tensor metadata for refit / hf side info @@ -1942,6 +1945,7 @@ def get_free_memory_bytes(self) -> int: @torch.no_grad() @wrap_with_nvtx_name("megatron_policy_worker/stream_weights_via_ipc_zmq") def stream_weights_via_ipc_zmq(self, buffer_size_bytes: int = 0) -> None: + """Stream model weights to peer process via ZMQ IPC socket.""" self.maybe_init_zmq() from nemo_rl.models.policy.utils import stream_weights_via_ipc_zmq_impl @@ -2054,9 +2058,9 @@ def offload_before_refit(self): @wrap_with_nvtx_name("megatron_policy_worker/offload_after_refit") def offload_after_refit(self): + """Offload as much as possible on the CPU.""" no_grad = torch.no_grad() no_grad.__enter__() - # Offload as much as possible on the CPU self.model = self.move_model(self.model, "cpu") self.model.eval() torch.randn(1).cuda() # wake up torch allocator diff --git a/nemo_rl/models/policy/utils.py b/nemo_rl/models/policy/utils.py index a2a47fd56f..79fb4c51e1 100644 --- a/nemo_rl/models/policy/utils.py +++ b/nemo_rl/models/policy/utils.py @@ -257,11 +257,11 @@ def calculate_aligned_size(size_bytes: int, alignment: int = 512) -> int: """Calculate aligned size for memory alignment. Args: - size_bytes: Size in bytes to align - alignment: Alignment boundary in bytes (default 512) + size_bytes(int): Size in bytes to align + alignment(int): Alignment boundary in bytes (default 512) Returns: - Aligned size in bytes + Aligned size in bytes(int). """ return int(((size_bytes + alignment - 1) // alignment) * alignment) From 46c4c561a90e6599a1e158d1ce94144517eca9a4 Mon Sep 17 00:00:00 2001 From: Zhiyu Li Date: Wed, 8 Oct 2025 10:03:06 -0700 Subject: [PATCH 06/20] Update nemo_rl/models/generation/vllm/vllm_backend.py Co-authored-by: Guyue Huang <140554423+guyueh1@users.noreply.github.com> Signed-off-by: Zhiyu Li --- nemo_rl/models/generation/vllm/vllm_backend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo_rl/models/generation/vllm/vllm_backend.py b/nemo_rl/models/generation/vllm/vllm_backend.py index d17fdd27ad..7d4047615d 100644 --- a/nemo_rl/models/generation/vllm/vllm_backend.py +++ b/nemo_rl/models/generation/vllm/vllm_backend.py @@ -73,7 +73,7 @@ def maybe_init_zmq(self): self.zmq_socket = self.zmq_context.socket( # pyrefly: ignore[implicitly-defined-attribute] This class does not define __init__ so assignments like this should be ignored zmq.REP ) - self.zmq_socket.setsockopt(zmq.SNDTIMEO, 30000) # 30s + self.zmq_socket.setsockopt(zmq.SNDTIMEO, 30000) # set timeout to 30 seconds self.zmq_socket.setsockopt(zmq.RCVTIMEO, 30000) # 30s self.zmq_socket.setsockopt(zmq.LINGER, 0) self.zmq_socket.connect(self.get_zmq_address()) From 39770f110a02b8e2bff087591363ff70a49dc9d5 Mon Sep 17 00:00:00 2001 From: Zhiyu Li Date: Wed, 8 Oct 2025 10:03:13 -0700 Subject: [PATCH 07/20] Update nemo_rl/models/generation/vllm/vllm_backend.py Co-authored-by: Guyue Huang <140554423+guyueh1@users.noreply.github.com> Signed-off-by: Zhiyu Li --- .../models/generation/vllm/vllm_backend.py | 41 ++++++++++++------- nemo_rl/models/policy/utils.py | 20 ++++++--- 2 files changed, 41 insertions(+), 20 deletions(-) diff --git a/nemo_rl/models/generation/vllm/vllm_backend.py b/nemo_rl/models/generation/vllm/vllm_backend.py index 7d4047615d..113b00a2f6 100644 --- a/nemo_rl/models/generation/vllm/vllm_backend.py +++ b/nemo_rl/models/generation/vllm/vllm_backend.py @@ -15,9 +15,10 @@ from typing import Any import torch -import zmq from torch.multiprocessing.reductions import rebuild_cuda_tensor +import zmq +from nemo_rl.models.policy.utils import IPCProtocol, calculate_aligned_size from nemo_rl.utils.nsys import wrap_with_nvtx_name from nemo_rl.utils.packed_tensor import packed_broadcast_consumer @@ -32,6 +33,25 @@ ) +def rebuild_cuda_tensor_from_ipc( + cuda_ipc_handle: tuple, device_id: int +) -> torch.Tensor: + """Rebuild a CUDA tensor from an IPC handle. + + Args: + cuda_ipc_handle: Tuple containing the CUDA IPC handle data + device_id: Target CUDA device ID + + Returns: + Reconstructed CUDA tensor on the target device + """ + func = rebuild_cuda_tensor + args = cuda_ipc_handle[0] + list_args = list(args) + list_args[6] = device_id + return func(*list_args) + + class VllmInternalWorkerExtension: def init_collective( self, @@ -74,7 +94,7 @@ def maybe_init_zmq(self): zmq.REP ) self.zmq_socket.setsockopt(zmq.SNDTIMEO, 30000) # set timeout to 30 seconds - self.zmq_socket.setsockopt(zmq.RCVTIMEO, 30000) # 30s + self.zmq_socket.setsockopt(zmq.RCVTIMEO, 30000) # set timeout to 30 seconds self.zmq_socket.setsockopt(zmq.LINGER, 0) self.zmq_socket.connect(self.get_zmq_address()) @@ -99,24 +119,17 @@ def update_weights_via_ipc_zmq(self) -> bool: try: self.maybe_init_zmq() - from nemo_rl.models.policy.utils import calculate_aligned_size - while True: # Blocking receive with timeout (this is the main operation) payload = self.zmq_socket.recv_pyobj() - if payload == "complete": + if payload == IPCProtocol.COMPLETE: # means the update is done - self.zmq_socket.send(b"") + self.zmq_socket.send(IPCProtocol.ACK.value.encode()) break - packed_tensor_handle, list_keys, used_bytes = payload - device_id = self.device.index - func = rebuild_cuda_tensor - args = packed_tensor_handle[0] - list_args = list(args) - list_args[6] = device_id - buffer = func(*list_args) + ipc_handle, list_keys, used_bytes = payload + buffer = rebuild_cuda_tensor_from_ipc(ipc_handle, self.device.index) weights = [] offset = 0 @@ -157,7 +170,7 @@ def update_weights_via_ipc_zmq(self) -> bool: del weights, buffer weights = None buffer = None - self.zmq_socket.send(b"") + self.zmq_socket.send(IPCProtocol.ACK.value.encode()) gc.collect() torch.cuda.empty_cache() diff --git a/nemo_rl/models/policy/utils.py b/nemo_rl/models/policy/utils.py index 79fb4c51e1..55290579b6 100644 --- a/nemo_rl/models/policy/utils.py +++ b/nemo_rl/models/policy/utils.py @@ -15,6 +15,7 @@ import gc import importlib import os +from enum import Enum from typing import Any, Dict import torch @@ -71,6 +72,12 @@ } +class IPCProtocol(Enum): + """IPC protocol constants for ZMQ weight streaming.""" + COMPLETE = "complete" + ACK = "ack" + + def resolve_model_class(model_name: str) -> Any: """Resolve the appropriate model class for a given model name.""" if NEMO_AUTOMODEL_AVAILABLE: @@ -271,7 +278,7 @@ def stream_weights_via_ipc_zmq_impl( ) -> None: """Shared implementation for streaming weights via IPC ZMQ with improved memory management. - Uses double buffering to enable overlapping communication while reusing buffers + Uses ping-pong double buffering to enable overlapping communication while reusing buffers to reduce memory allocation overhead and improve stability. Args: @@ -291,8 +298,9 @@ def send_buffer_group_overlap(buffer, param_names, used_bytes, await_recv) -> bo if await_recv: zmq_socket.recv() - serialized = (cuda_ipc_handle, tuple(param_names), used_bytes) - zmq_socket.send_pyobj(serialized) + # Payload tuple: (cuda_ipc_handle, param_names, used_bytes) + payload = (cuda_ipc_handle, param_names, used_bytes) + zmq_socket.send_pyobj(payload) return True # pending_recv = True def allocate_buffer(device): @@ -313,7 +321,7 @@ def pack_tensor(buffer, tensor, used_bytes) -> int: ) return used_bytes + calculate_aligned_size(tensor_bytes) - # Initialize double buffering system + # Initialize ping-pong double buffering buffer_a: torch.Tensor | None = None buffer_b: torch.Tensor | None = None current_buffer: torch.Tensor | None = None @@ -343,7 +351,7 @@ def pack_tensor(buffer, tensor, used_bytes) -> int: ) count_of_groups += 1 - # Switch buffers for double buffering + # Switch buffers for ping-pong double buffering current_buffer = buffer_b if current_buffer is buffer_a else buffer_a used_bytes, param_names = 0, [] @@ -364,7 +372,7 @@ def pack_tensor(buffer, tensor, used_bytes) -> int: # Final synchronization and completion signal torch.cuda.current_stream().synchronize() - zmq_socket.send_pyobj("complete") + zmq_socket.send_pyobj(IPCProtocol.COMPLETE) zmq_socket.recv() if rank == 0: From e63bf9f32df2b89ef6df6d3faea9cebeca91ea0c Mon Sep 17 00:00:00 2001 From: Zhiyu Li Date: Wed, 8 Oct 2025 15:08:06 -0700 Subject: [PATCH 08/20] Update nemo_rl/models/generation/vllm/vllm_backend.py Co-authored-by: Guyue Huang <140554423+guyueh1@users.noreply.github.com> Signed-off-by: Zhiyu Li --- nemo_rl/models/generation/vllm/vllm_backend.py | 1 - 1 file changed, 1 deletion(-) diff --git a/nemo_rl/models/generation/vllm/vllm_backend.py b/nemo_rl/models/generation/vllm/vllm_backend.py index 113b00a2f6..e1427e19b5 100644 --- a/nemo_rl/models/generation/vllm/vllm_backend.py +++ b/nemo_rl/models/generation/vllm/vllm_backend.py @@ -175,7 +175,6 @@ def update_weights_via_ipc_zmq(self) -> bool: gc.collect() torch.cuda.empty_cache() return True - except Exception as e: print( f"Error in VllmInternalWorkerExtension.update_weights_via_ipc_zmq: {e}" From 22f5c10280e5c2b40d770842a124a33f6918d5ba Mon Sep 17 00:00:00 2001 From: Zhiyu Li Date: Wed, 8 Oct 2025 16:32:35 -0700 Subject: [PATCH 09/20] lint Signed-off-by: Zhiyu Li --- nemo_rl/models/generation/vllm/vllm_backend.py | 2 +- nemo_rl/models/policy/utils.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/nemo_rl/models/generation/vllm/vllm_backend.py b/nemo_rl/models/generation/vllm/vllm_backend.py index e1427e19b5..7eae3b572d 100644 --- a/nemo_rl/models/generation/vllm/vllm_backend.py +++ b/nemo_rl/models/generation/vllm/vllm_backend.py @@ -15,9 +15,9 @@ from typing import Any import torch +import zmq from torch.multiprocessing.reductions import rebuild_cuda_tensor -import zmq from nemo_rl.models.policy.utils import IPCProtocol, calculate_aligned_size from nemo_rl.utils.nsys import wrap_with_nvtx_name from nemo_rl.utils.packed_tensor import packed_broadcast_consumer diff --git a/nemo_rl/models/policy/utils.py b/nemo_rl/models/policy/utils.py index 55290579b6..945a45ac25 100644 --- a/nemo_rl/models/policy/utils.py +++ b/nemo_rl/models/policy/utils.py @@ -74,6 +74,7 @@ class IPCProtocol(Enum): """IPC protocol constants for ZMQ weight streaming.""" + COMPLETE = "complete" ACK = "ack" From 8d0b6bd30c42634580c8843cdcb4243af187a5c9 Mon Sep 17 00:00:00 2001 From: Zhiyu Li Date: Thu, 9 Oct 2025 09:51:35 -0700 Subject: [PATCH 10/20] Update nemo_rl/algorithms/grpo.py Co-authored-by: Yuki Huang Signed-off-by: Zhiyu Li --- nemo_rl/algorithms/grpo.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index 8cbf838ab8..27ca566768 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -752,9 +752,10 @@ def refit_policy_generation( if _refit_buffer_size_gb is not None: buffer_size_bytes = _refit_buffer_size_gb * (1024**3) else: - memory_ratio = os.getenv("NRL_REFIT_BUFFER_MEMORY_RATIO", "0.15") + memory_ratio = os.getenv("NRL_REFIT_BUFFER_MEMORY_RATIO", "0.3") + # divides by 2 since we have 2 buffers for overlap buffer_size_bytes = int( - policy.get_free_memory_bytes() * float(memory_ratio) + policy.get_free_memory_bytes() * float(memory_ratio) / 2 ) futures_train = policy.stream_weights_via_ipc_zmq( From e4b100335ea528034383af623793e998426a4832 Mon Sep 17 00:00:00 2001 From: Zhiyu Li Date: Thu, 9 Oct 2025 10:12:07 -0700 Subject: [PATCH 11/20] better comments Signed-off-by: Zhiyu Li --- nemo_rl/algorithms/grpo.py | 3 +-- nemo_rl/models/policy/dtensor_policy_worker.py | 4 ++-- nemo_rl/models/policy/dtensor_policy_worker_v2.py | 4 ++-- nemo_rl/models/policy/megatron_policy_worker.py | 4 ++-- nemo_rl/models/policy/utils.py | 4 +++- 5 files changed, 10 insertions(+), 9 deletions(-) diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index 27ca566768..b890ad677e 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -753,9 +753,8 @@ def refit_policy_generation( buffer_size_bytes = _refit_buffer_size_gb * (1024**3) else: memory_ratio = os.getenv("NRL_REFIT_BUFFER_MEMORY_RATIO", "0.3") - # divides by 2 since we have 2 buffers for overlap buffer_size_bytes = int( - policy.get_free_memory_bytes() * float(memory_ratio) / 2 + policy.get_free_memory_bytes() * float(memory_ratio) ) futures_train = policy.stream_weights_via_ipc_zmq( diff --git a/nemo_rl/models/policy/dtensor_policy_worker.py b/nemo_rl/models/policy/dtensor_policy_worker.py index c9f17fe559..29a84ebb1b 100644 --- a/nemo_rl/models/policy/dtensor_policy_worker.py +++ b/nemo_rl/models/policy/dtensor_policy_worker.py @@ -1705,8 +1705,8 @@ def maybe_init_zmq(self): if not hasattr(self, "zmq_socket"): self.zmq_context = zmq.Context() self.zmq_socket = self.zmq_context.socket(zmq.REQ) - self.zmq_socket.setsockopt(zmq.SNDTIMEO, 30000) # 30s - self.zmq_socket.setsockopt(zmq.RCVTIMEO, 30000) # 30s + self.zmq_socket.setsockopt(zmq.SNDTIMEO, 30000) # set timeout to 30 seconds + self.zmq_socket.setsockopt(zmq.RCVTIMEO, 30000) # set timeout to 30 seconds self.zmq_socket.setsockopt(zmq.LINGER, 0) self.zmq_socket.bind(self.get_zmq_address()) diff --git a/nemo_rl/models/policy/dtensor_policy_worker_v2.py b/nemo_rl/models/policy/dtensor_policy_worker_v2.py index 224c4873cf..6ffdab4573 100644 --- a/nemo_rl/models/policy/dtensor_policy_worker_v2.py +++ b/nemo_rl/models/policy/dtensor_policy_worker_v2.py @@ -1666,8 +1666,8 @@ def maybe_init_zmq(self): if not hasattr(self, "zmq_socket"): self.zmq_context = zmq.Context() self.zmq_socket = self.zmq_context.socket(zmq.REQ) - self.zmq_socket.setsockopt(zmq.SNDTIMEO, 30000) # 30s - self.zmq_socket.setsockopt(zmq.RCVTIMEO, 30000) # 30s + self.zmq_socket.setsockopt(zmq.SNDTIMEO, 30000) # set timeout to 30 seconds + self.zmq_socket.setsockopt(zmq.RCVTIMEO, 30000) # set timeout to 30 seconds self.zmq_socket.setsockopt(zmq.LINGER, 0) self.zmq_socket.bind(self.get_zmq_address()) diff --git a/nemo_rl/models/policy/megatron_policy_worker.py b/nemo_rl/models/policy/megatron_policy_worker.py index 5006802f03..1b7b1d1186 100644 --- a/nemo_rl/models/policy/megatron_policy_worker.py +++ b/nemo_rl/models/policy/megatron_policy_worker.py @@ -1861,8 +1861,8 @@ def maybe_init_zmq(self): if not hasattr(self, "zmq_socket"): self.zmq_context = zmq.Context() self.zmq_socket = self.zmq_context.socket(zmq.REQ) - self.zmq_socket.setsockopt(zmq.SNDTIMEO, 30000) # 30s - self.zmq_socket.setsockopt(zmq.RCVTIMEO, 30000) # 30s + self.zmq_socket.setsockopt(zmq.SNDTIMEO, 30000) # set timeout to 30 seconds + self.zmq_socket.setsockopt(zmq.RCVTIMEO, 30000) # set timeout to 30 seconds self.zmq_socket.setsockopt(zmq.LINGER, 0) self.zmq_socket.bind(self.get_zmq_address()) diff --git a/nemo_rl/models/policy/utils.py b/nemo_rl/models/policy/utils.py index 945a45ac25..9cfbf1c3d5 100644 --- a/nemo_rl/models/policy/utils.py +++ b/nemo_rl/models/policy/utils.py @@ -284,11 +284,13 @@ def stream_weights_via_ipc_zmq_impl( Args: params_generator: Generator yielding (name, tensor) pairs - buffer_size_bytes: Size of buffer in bytes for batching parameters + buffer_size_bytes: total size of buffer in bytes for batching parameters zmq_socket: ZMQ socket for communication rank: Worker rank for logging worker_name: Name of the worker for logging """ + # Divide total buffer size by 2 because we use two individual buffers (ping-pong) for overlapping communication. + buffer_size_bytes = buffer_size_bytes // 2 def send_buffer_group_overlap(buffer, param_names, used_bytes, await_recv) -> bool: """Send a group of parameters and return new pending_recv state.""" From 648e43dfc407952bc0be57c912e6bf9198db395a Mon Sep 17 00:00:00 2001 From: Zhiyu Li Date: Wed, 15 Oct 2025 05:06:42 -0700 Subject: [PATCH 12/20] fix test failure Signed-off-by: Zhiyu Li --- nemo_rl/models/policy/dtensor_policy_worker.py | 3 +++ nemo_rl/models/policy/dtensor_policy_worker_v2.py | 3 +++ tests/unit/models/generation/test_vllm_generation.py | 12 ++++++------ 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/nemo_rl/models/policy/dtensor_policy_worker.py b/nemo_rl/models/policy/dtensor_policy_worker.py index 29a84ebb1b..d823d7ff53 100644 --- a/nemo_rl/models/policy/dtensor_policy_worker.py +++ b/nemo_rl/models/policy/dtensor_policy_worker.py @@ -1732,6 +1732,9 @@ def get_free_memory_bytes(self) -> int: def stream_weights_via_ipc_zmq(self, buffer_size_bytes: int = 0) -> None: """Stream model weights to peer process via ZMQ IPC socket.""" self.maybe_init_zmq() + # Manually move model to cuda for cpu offload case + if self.cpu_offload: + self.model = self.move_to_cuda(self.model) from nemo_rl.models.policy.utils import stream_weights_via_ipc_zmq_impl diff --git a/nemo_rl/models/policy/dtensor_policy_worker_v2.py b/nemo_rl/models/policy/dtensor_policy_worker_v2.py index 6ffdab4573..8fb447a043 100644 --- a/nemo_rl/models/policy/dtensor_policy_worker_v2.py +++ b/nemo_rl/models/policy/dtensor_policy_worker_v2.py @@ -1693,6 +1693,9 @@ def get_free_memory_bytes(self) -> int: def stream_weights_via_ipc_zmq(self, buffer_size_bytes: int = 0) -> None: """Stream model weights to peer process via ZMQ IPC socket.""" self.maybe_init_zmq() + # Manually move model to cuda for cpu offload case + if self.cpu_offload: + self.model = self.move_to_cuda(self.model) from nemo_rl.models.policy.utils import stream_weights_via_ipc_zmq_impl diff --git a/tests/unit/models/generation/test_vllm_generation.py b/tests/unit/models/generation/test_vllm_generation.py index 478eac9a7c..4b1bc12b11 100644 --- a/tests/unit/models/generation/test_vllm_generation.py +++ b/tests/unit/models/generation/test_vllm_generation.py @@ -1594,11 +1594,11 @@ def test_vllm_weight_update_and_prefix_cache_reset( ) print("Updating vLLM weights from HF policy...") - grouped_param_keys = lm_policy.prepare_weights_for_ipc() - for keys in grouped_param_keys: - ipc_handles = lm_policy.get_weights_ipc_handles(keys) - update_success = vllm_policy.update_weights_from_ipc_handles(ipc_handles) - assert update_success, "Weight update should succeed" + + buffer_size_bytes = int(lm_policy.get_free_memory_bytes() * 0.3) + lm_policy.stream_weights_via_ipc_zmq(buffer_size_bytes=buffer_size_bytes) + update_success = vllm_policy.update_weights_via_ipc_zmq() + assert update_success, "Weight update should succeed" print("vLLM weights successfully updated.") print("Running Generation 2 (Weights Updated, Cache Still Active)...") @@ -1675,7 +1675,7 @@ def test_vllm_weight_update_memory(cluster, tokenizer): lm_policy, vllm_policy, vllm_config["colocated"]["enabled"], - _refit_buffer_size_gb=1, + _refit_buffer_size_gb=1.5, ) gpu_infos = ray.get([w.get_gpu_info.remote() for w in workers]) From 31d846671500794d77402ca9cc3f8194f69001a0 Mon Sep 17 00:00:00 2001 From: Zhiyu Li Date: Thu, 16 Oct 2025 14:32:06 -0700 Subject: [PATCH 13/20] fix another failure Signed-off-by: Zhiyu Li --- tests/unit/models/generation/test_vllm_generation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/models/generation/test_vllm_generation.py b/tests/unit/models/generation/test_vllm_generation.py index 4b1bc12b11..84448dca41 100644 --- a/tests/unit/models/generation/test_vllm_generation.py +++ b/tests/unit/models/generation/test_vllm_generation.py @@ -2133,7 +2133,7 @@ def test_vllm_megatron_weight_update_memory(cluster, tokenizer): megatron_policy, vllm_policy, vllm_config["colocated"]["enabled"], - _refit_buffer_size_gb=1, + _refit_buffer_size_gb=1.5, ) gpu_infos = ray.get([w.get_gpu_info.remote() for w in workers]) From 378db630e08ab74771ca1ce07f06bfb61b94ef85 Mon Sep 17 00:00:00 2001 From: Zhiyu Li Date: Mon, 20 Oct 2025 03:15:16 -0700 Subject: [PATCH 14/20] add cleanup/shutdown Signed-off-by: Zhiyu Li --- nemo_rl/models/generation/vllm/vllm_backend.py | 7 +++++++ nemo_rl/models/generation/vllm/vllm_worker.py | 3 +++ nemo_rl/models/generation/vllm/vllm_worker_async.py | 4 +++- nemo_rl/models/policy/dtensor_policy_worker.py | 4 ++++ nemo_rl/models/policy/dtensor_policy_worker_v2.py | 4 ++++ nemo_rl/models/policy/lm_policy.py | 2 +- nemo_rl/models/policy/megatron_policy_worker.py | 5 ++++- 7 files changed, 26 insertions(+), 3 deletions(-) diff --git a/nemo_rl/models/generation/vllm/vllm_backend.py b/nemo_rl/models/generation/vllm/vllm_backend.py index 7eae3b572d..d4f8b44acd 100644 --- a/nemo_rl/models/generation/vllm/vllm_backend.py +++ b/nemo_rl/models/generation/vllm/vllm_backend.py @@ -226,6 +226,13 @@ def _load_model_weights(weights, model_runner): return True + def cleanup(self) -> None: + """Shutdown and cleanup resources.""" + # Close ZMQ socket and context if they exist + if hasattr(self, "zmq_socket"): + self.zmq_socket.close() + self.zmq_context.term() + def start_gpu_profiling(self) -> None: """Start GPU profiling.""" torch.cuda.profiler.start() diff --git a/nemo_rl/models/generation/vllm/vllm_worker.py b/nemo_rl/models/generation/vllm/vllm_worker.py index 3c3d316c32..7ce826a27f 100644 --- a/nemo_rl/models/generation/vllm/vllm_worker.py +++ b/nemo_rl/models/generation/vllm/vllm_worker.py @@ -825,6 +825,9 @@ def shutdown(self) -> bool: """Clean up vLLM resources.""" try: if self.llm is not None: + # Clean up extension resources (e.g., ZMQ sockets) + self.llm.collective_rpc("cleanup", args=tuple()) + # Explicitly delete the engine. This may trigger its __del__ method. del self.llm diff --git a/nemo_rl/models/generation/vllm/vllm_worker_async.py b/nemo_rl/models/generation/vllm/vllm_worker_async.py index d0b379b126..c456c62c03 100644 --- a/nemo_rl/models/generation/vllm/vllm_worker_async.py +++ b/nemo_rl/models/generation/vllm/vllm_worker_async.py @@ -977,10 +977,12 @@ async def wake_up_async(self, **kwargs): await self.llm.wake_up(**wake_up_args) - def shutdown(self) -> bool: + async def shutdown(self) -> bool: """Clean up vLLM resources.""" try: if self.llm is not None: + # Clean up extension resources (e.g., ZMQ sockets) + await self.llm.collective_rpc("cleanup", args=tuple()) try: self.llm.shutdown() except Exception as e_stop: diff --git a/nemo_rl/models/policy/dtensor_policy_worker.py b/nemo_rl/models/policy/dtensor_policy_worker.py index d823d7ff53..34ad4e16d5 100644 --- a/nemo_rl/models/policy/dtensor_policy_worker.py +++ b/nemo_rl/models/policy/dtensor_policy_worker.py @@ -1917,6 +1917,10 @@ def load_checkpoint( def shutdown(self) -> None: """Shutdown the policy.""" + # Clean up extension resources like ZMQ sockets + if hasattr(self, "zmq_socket"): + self.zmq_socket.close() + self.zmq_context.term() def start_gpu_profiling(self) -> None: """Start GPU profiling.""" diff --git a/nemo_rl/models/policy/dtensor_policy_worker_v2.py b/nemo_rl/models/policy/dtensor_policy_worker_v2.py index 8fb447a043..7b60673e7f 100644 --- a/nemo_rl/models/policy/dtensor_policy_worker_v2.py +++ b/nemo_rl/models/policy/dtensor_policy_worker_v2.py @@ -1901,6 +1901,10 @@ def load_checkpoint( def shutdown(self) -> None: """Shutdown the policy.""" + # Clean up extension resources like ZMQ sockets + if hasattr(self, "zmq_socket"): + self.zmq_socket.close() + self.zmq_context.term() def start_gpu_profiling(self) -> None: """Start GPU profiling.""" diff --git a/nemo_rl/models/policy/lm_policy.py b/nemo_rl/models/policy/lm_policy.py index 8d7d6c3b16..811695b6fe 100644 --- a/nemo_rl/models/policy/lm_policy.py +++ b/nemo_rl/models/policy/lm_policy.py @@ -743,7 +743,7 @@ def __del__(self) -> None: the object is lost due to leaving a function scope. It's always recommended that the user calls worker_group.shutdown(). """ - self.worker_group.shutdown() + self.worker_group.shutdown(cleanup_method="shutdown") def start_gpu_profiling(self) -> None: """Start GPU profiling.""" diff --git a/nemo_rl/models/policy/megatron_policy_worker.py b/nemo_rl/models/policy/megatron_policy_worker.py index 1b7b1d1186..1db9c4fa44 100644 --- a/nemo_rl/models/policy/megatron_policy_worker.py +++ b/nemo_rl/models/policy/megatron_policy_worker.py @@ -2216,7 +2216,10 @@ def load_checkpoint(self, weights_path: str, optimizer_path: Optional[str] = Non def shutdown(self): """Shutdown the policy.""" - pass + # Clean up extension resources like ZMQ sockets + if hasattr(self, "zmq_socket"): + self.zmq_socket.close() + self.zmq_context.term() def start_gpu_profiling(self) -> None: """Start GPU profiling.""" From 078599eee438c7fab8f5f2480ba6a0b8de7f1e56 Mon Sep 17 00:00:00 2001 From: Zhiyu Li Date: Tue, 21 Oct 2025 01:34:50 -0700 Subject: [PATCH 15/20] add comments Signed-off-by: Zhiyu Li --- nemo_rl/algorithms/grpo.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index b890ad677e..deaf99aa39 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -752,6 +752,8 @@ def refit_policy_generation( if _refit_buffer_size_gb is not None: buffer_size_bytes = _refit_buffer_size_gb * (1024**3) else: + # Empirically sets ratio as 30% to maximize efficiency. + # The remaining 70% is a necessary buffer reserved for the parameter all-gathering across the expert-parallelism dimension. memory_ratio = os.getenv("NRL_REFIT_BUFFER_MEMORY_RATIO", "0.3") buffer_size_bytes = int( policy.get_free_memory_bytes() * float(memory_ratio) From 5a4f180fa2f7bd53b6da170e4da08605f0e40440 Mon Sep 17 00:00:00 2001 From: Zhiyu Li Date: Tue, 21 Oct 2025 01:45:23 -0700 Subject: [PATCH 16/20] add test for stream_weights_via_ipc_zmq_impl Signed-off-by: Zhiyu Li --- .../models/generation/vllm/vllm_backend.py | 28 +-- nemo_rl/models/policy/utils.py | 12 ++ tests/unit/models/policy/test_utils.py | 200 ++++++++++++++++++ 3 files changed, 218 insertions(+), 22 deletions(-) diff --git a/nemo_rl/models/generation/vllm/vllm_backend.py b/nemo_rl/models/generation/vllm/vllm_backend.py index d4f8b44acd..9f072d76e0 100644 --- a/nemo_rl/models/generation/vllm/vllm_backend.py +++ b/nemo_rl/models/generation/vllm/vllm_backend.py @@ -15,10 +15,13 @@ from typing import Any import torch -import zmq -from torch.multiprocessing.reductions import rebuild_cuda_tensor -from nemo_rl.models.policy.utils import IPCProtocol, calculate_aligned_size +import zmq +from nemo_rl.models.policy.utils import ( + IPCProtocol, + calculate_aligned_size, + rebuild_cuda_tensor_from_ipc, +) from nemo_rl.utils.nsys import wrap_with_nvtx_name from nemo_rl.utils.packed_tensor import packed_broadcast_consumer @@ -33,25 +36,6 @@ ) -def rebuild_cuda_tensor_from_ipc( - cuda_ipc_handle: tuple, device_id: int -) -> torch.Tensor: - """Rebuild a CUDA tensor from an IPC handle. - - Args: - cuda_ipc_handle: Tuple containing the CUDA IPC handle data - device_id: Target CUDA device ID - - Returns: - Reconstructed CUDA tensor on the target device - """ - func = rebuild_cuda_tensor - args = cuda_ipc_handle[0] - list_args = list(args) - list_args[6] = device_id - return func(*list_args) - - class VllmInternalWorkerExtension: def init_collective( self, diff --git a/nemo_rl/models/policy/utils.py b/nemo_rl/models/policy/utils.py index 9cfbf1c3d5..b5a6c3a086 100644 --- a/nemo_rl/models/policy/utils.py +++ b/nemo_rl/models/policy/utils.py @@ -19,6 +19,7 @@ from typing import Any, Dict import torch +from torch.multiprocessing.reductions import rebuild_cuda_tensor from transformers import ( AutoConfig, AutoModelForCausalLM, @@ -393,3 +394,14 @@ def pack_tensor(buffer, tensor, used_bytes) -> int: # Force garbage collection and clear CUDA cache gc.collect() torch.cuda.empty_cache() + + +def rebuild_cuda_tensor_from_ipc( + cuda_ipc_handle: tuple, device_id: int +) -> torch.Tensor: + """Rebuild a CUDA tensor from an IPC handle.""" + func = rebuild_cuda_tensor + args = cuda_ipc_handle[0] + list_args = list(args) + list_args[6] = device_id + return func(*list_args) diff --git a/tests/unit/models/policy/test_utils.py b/tests/unit/models/policy/test_utils.py index 8fb4d8f8b2..daabcb7999 100644 --- a/tests/unit/models/policy/test_utils.py +++ b/tests/unit/models/policy/test_utils.py @@ -12,11 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. +import multiprocessing import os +import time import unittest.mock +import pytest +import torch + +import zmq from nemo_rl.models.policy.utils import ( + IPCProtocol, + calculate_aligned_size, get_megatron_checkpoint_dir, + rebuild_cuda_tensor_from_ipc, + stream_weights_via_ipc_zmq_impl, ) @@ -106,3 +116,193 @@ def test_function_prints_selected_directory(self, capsys): f"Using default megatron checkpoint dir: {expected_dir}" in captured.out ) assert result == expected_dir + + +def server_process( + zmq_addr: str, + known_tensors: list[tuple[str, torch.Tensor]], + buffer_size_bytes: int, + ready_queue: multiprocessing.Queue, +) -> None: + """Server process that streams tensors via IPC ZMQ.""" + try: + device = torch.device("cuda:0") + gpu_tensors = [(name, tensor.to(device)) for name, tensor in known_tensors] + + context = zmq.Context() + socket = context.socket(zmq.PAIR) + socket.bind(zmq_addr) + ready_queue.put(("ready", None)) + + stream_weights_via_ipc_zmq_impl( + (t for t in gpu_tensors), + buffer_size_bytes, + socket, + rank=0, + worker_name="test_server", + ) + + socket.close() + context.term() + except Exception as e: + ready_queue.put(("error", str(e))) + + +def client_process( + zmq_addr: str, + known_tensors_data: list[tuple[str, tuple, torch.dtype, torch.Tensor]], + buffer_size_bytes: int, + result_queue: multiprocessing.Queue, +) -> None: + """Client process that receives and validates tensors via IPC ZMQ.""" + try: + device = torch.device("cuda:0") + + # Prepare expected tensors on GPU + expected_tensors = { + name: tensor.to(device) for name, _, _, tensor in known_tensors_data + } + state_dict_info = { + name: (shape, dtype) for name, shape, dtype, _ in known_tensors_data + } + + context = zmq.Context() + socket = context.socket(zmq.PAIR) + socket.connect(zmq_addr) + + # Receive and validate loop + while True: + payload = socket.recv_pyobj() + if payload == IPCProtocol.COMPLETE: + socket.send(IPCProtocol.ACK.value.encode()) + break + + ipc_handle, list_keys, used_bytes = payload + buffer = rebuild_cuda_tensor_from_ipc(ipc_handle, device.index) + + offset = 0 + for key in list_keys: + shape, dtype = state_dict_info[key] + shape = torch.Size(shape) if isinstance(shape, list) else shape + size_in_bytes = dtype.itemsize * shape.numel() + + tensor = ( + buffer[offset : offset + size_in_bytes] + .view(dtype=dtype) + .view(shape) + ) + expected = expected_tensors[key] + + # Validate tensor + assert tensor.shape == expected.shape, f"Shape mismatch for {key}" + assert tensor.dtype == expected.dtype, f"Dtype mismatch for {key}" + assert torch.allclose(tensor, expected, rtol=1e-7, atol=1e-7), ( + f"Values mismatch for {key}" + ) + + offset += calculate_aligned_size(size_in_bytes) + + assert offset == used_bytes, f"Offset mismatch: {offset} != {used_bytes}" + socket.send(b"") + + socket.close() + context.term() + result_queue.put(("success", "All tensors validated")) + except Exception as e: + result_queue.put(("error", str(e))) + + +class TestStreamWeightsViaIPC: + """Test suite for IPC weight streaming functionality.""" + + TIMEOUT = 10 + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + @pytest.mark.parametrize( + "test_case,tensor_specs,buffer_size_bytes,test_description", + [ + ( + "large_buffer", + [ + ("tensor_1", (10, 20), torch.float32), # 0.78KB + ("tensor_2", (5, 15, 25), torch.float32), # 7.32KB + ("tensor_3", (100,), torch.float16), # 0.20KB + ("tensor_4", (50, 50), torch.bfloat16), # 4.88KB + ("tensor_5", (8, 16, 32), torch.float32), # 16.00KB + ], # Total: 29.18KB + 100 * 1024, # 100 KB - large buffer for single batch (50KB per side) + "Test with various shapes/dtypes in large buffer (single batch)", + ), + ( + "small_buffer", + [ + ("small_1", (30, 30), torch.float32), # 3.52KB + ("small_2", (20, 40), torch.float16), # 1.56KB + ("small_3", (128,), torch.float32), # 0.50KB + ("small_4", (25, 35), torch.float32), # 3.42KB + ], # Total: 9.00KB + 10 * 1024, # 10 KB - forces multiple batches (5KB per side) + "Test with small buffer forcing multiple batches", + ), + ], + ) + def test_stream_weights_via_ipc_zmq_impl( + self, test_case, tensor_specs, buffer_size_bytes, test_description + ): + """Test streaming weights via IPC ZMQ between server and client processes.""" + # Generate test tensors + known_tensors = [ + (name, torch.randn(*shape, dtype=dtype)) + for name, shape, dtype in tensor_specs + ] + known_tensors_data = [ + (name, list(t.shape), t.dtype, t) for name, t in known_tensors + ] + + # Create unique socket path and queues + socket_path = f"/tmp/test_ipc_zmq_{test_case}_{os.getpid()}_{time.time()}" + zmq_addr = f"ipc://{socket_path}" + + mp_context = multiprocessing.get_context("spawn") + ready_queue = mp_context.Queue() + result_queue = mp_context.Queue() + + # Start server and client + server_proc = mp_context.Process( + target=server_process, + args=(zmq_addr, known_tensors, buffer_size_bytes, ready_queue), + ) + server_proc.start() + + status, msg = ready_queue.get(timeout=self.TIMEOUT) + assert status == "ready", f"Server failed: {msg}" + + client_proc = mp_context.Process( + target=client_process, + args=(zmq_addr, known_tensors_data, buffer_size_bytes, result_queue), + ) + client_proc.start() + + # Wait and validate + try: + server_proc.join(timeout=self.TIMEOUT) + client_proc.join(timeout=self.TIMEOUT) + + assert server_proc.exitcode == 0, f"Server failed: {server_proc.exitcode}" + assert client_proc.exitcode == 0, f"Client failed: {client_proc.exitcode}" + + status, msg = result_queue.get(timeout=self.TIMEOUT) + assert status == "success", f"Validation failed: {msg}" + finally: + for proc in [server_proc, client_proc]: + if proc and proc.is_alive(): + proc.terminate() + proc.join(timeout=self.TIMEOUT) + if proc.is_alive(): + proc.kill() + + if os.path.exists(socket_path): + try: + os.unlink(socket_path) + except: + pass From 4a7cfd0ec8fd555be0d93a3a5af84814cc1ebf6c Mon Sep 17 00:00:00 2001 From: Zhiyu Li Date: Tue, 21 Oct 2025 01:51:44 -0700 Subject: [PATCH 17/20] lint Signed-off-by: Zhiyu Li --- nemo_rl/models/generation/vllm/vllm_backend.py | 2 +- tests/unit/models/policy/test_utils.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/nemo_rl/models/generation/vllm/vllm_backend.py b/nemo_rl/models/generation/vllm/vllm_backend.py index 9f072d76e0..0bf2b224d3 100644 --- a/nemo_rl/models/generation/vllm/vllm_backend.py +++ b/nemo_rl/models/generation/vllm/vllm_backend.py @@ -15,8 +15,8 @@ from typing import Any import torch - import zmq + from nemo_rl.models.policy.utils import ( IPCProtocol, calculate_aligned_size, diff --git a/tests/unit/models/policy/test_utils.py b/tests/unit/models/policy/test_utils.py index daabcb7999..ccf7a43a35 100644 --- a/tests/unit/models/policy/test_utils.py +++ b/tests/unit/models/policy/test_utils.py @@ -19,8 +19,8 @@ import pytest import torch - import zmq + from nemo_rl.models.policy.utils import ( IPCProtocol, calculate_aligned_size, From 7a856340397c93184c275e3a7c3c72702c56e931 Mon Sep 17 00:00:00 2001 From: Zhiyu Li Date: Tue, 21 Oct 2025 17:51:16 -0700 Subject: [PATCH 18/20] better error message Signed-off-by: Zhiyu Li --- nemo_rl/models/policy/lm_policy.py | 3 +- tests/unit/models/policy/test_utils.py | 70 ++++++++++++++++++++------ 2 files changed, 58 insertions(+), 15 deletions(-) diff --git a/nemo_rl/models/policy/lm_policy.py b/nemo_rl/models/policy/lm_policy.py index 811695b6fe..0a111c1e31 100644 --- a/nemo_rl/models/policy/lm_policy.py +++ b/nemo_rl/models/policy/lm_policy.py @@ -743,7 +743,8 @@ def __del__(self) -> None: the object is lost due to leaving a function scope. It's always recommended that the user calls worker_group.shutdown(). """ - self.worker_group.shutdown(cleanup_method="shutdown") + if hasattr(self, "worker_group"): + self.worker_group.shutdown(cleanup_method="shutdown") def start_gpu_profiling(self) -> None: """Start GPU profiling.""" diff --git a/tests/unit/models/policy/test_utils.py b/tests/unit/models/policy/test_utils.py index ccf7a43a35..8145cfcc25 100644 --- a/tests/unit/models/policy/test_utils.py +++ b/tests/unit/models/policy/test_utils.py @@ -14,13 +14,15 @@ import multiprocessing import os +import sys import time +import traceback import unittest.mock import pytest import torch -import zmq +import zmq from nemo_rl.models.policy.utils import ( IPCProtocol, calculate_aligned_size, @@ -131,6 +133,10 @@ def server_process( context = zmq.Context() socket = context.socket(zmq.PAIR) + socket.setsockopt(zmq.LINGER, 0) # Close immediately on error + socket.setsockopt( + zmq.RCVTIMEO, 5000 + ) # 5 second timeout on recv to detect client disconnection socket.bind(zmq_addr) ready_queue.put(("ready", None)) @@ -141,11 +147,18 @@ def server_process( rank=0, worker_name="test_server", ) - + except Exception as e: + import sys + import traceback + + error_details = f"{type(e).__name__}: {str(e)}\n{traceback.format_exc()}" + ready_queue.put(("error", error_details)) + sys.exit( + 1 + ) # Exit with non-zero code so check_process_error detects the failure + finally: socket.close() context.term() - except Exception as e: - ready_queue.put(("error", str(e))) def client_process( @@ -168,6 +181,9 @@ def client_process( context = zmq.Context() socket = context.socket(zmq.PAIR) + socket.setsockopt( + zmq.LINGER, 0 + ) # Close immediately on error, don't wait for pending sends socket.connect(zmq_addr) # Receive and validate loop @@ -196,7 +212,7 @@ def client_process( # Validate tensor assert tensor.shape == expected.shape, f"Shape mismatch for {key}" assert tensor.dtype == expected.dtype, f"Dtype mismatch for {key}" - assert torch.allclose(tensor, expected, rtol=1e-7, atol=1e-7), ( + assert torch.allclose(tensor + 1, expected, rtol=1e-7, atol=1e-7), ( f"Values mismatch for {key}" ) @@ -205,11 +221,38 @@ def client_process( assert offset == used_bytes, f"Offset mismatch: {offset} != {used_bytes}" socket.send(b"") - socket.close() - context.term() result_queue.put(("success", "All tensors validated")) except Exception as e: - result_queue.put(("error", str(e))) + error_details = f"{type(e).__name__}: {str(e)}\n{traceback.format_exc()}" + result_queue.put(("error", error_details)) + sys.exit(1) + finally: + socket.close() + context.term() + + +def check_process_error( + proc: multiprocessing.Process, + queue: multiprocessing.Queue, + process_name: str, +) -> None: + """Check if a process failed and assert with detailed error message if available.""" + if proc.exitcode == 0: + return + + # Get error details from queue + error_msg = None + while not queue.empty(): + status, msg = queue.get_nowait() + if status == "error": + error_msg = msg + break + + if proc.exitcode is None: + assert False, f"{process_name} timed out" + else: + details = f"\n{error_msg}" if error_msg else "" + assert False, f"{process_name} failed (exitcode={proc.exitcode}){details}" class TestStreamWeightsViaIPC: @@ -288,9 +331,11 @@ def test_stream_weights_via_ipc_zmq_impl( server_proc.join(timeout=self.TIMEOUT) client_proc.join(timeout=self.TIMEOUT) - assert server_proc.exitcode == 0, f"Server failed: {server_proc.exitcode}" - assert client_proc.exitcode == 0, f"Client failed: {client_proc.exitcode}" + # Check client first since client failure often causes server to fail + check_process_error(client_proc, result_queue, "Client") + check_process_error(server_proc, ready_queue, "Server") + # Verify client success message status, msg = result_queue.get(timeout=self.TIMEOUT) assert status == "success", f"Validation failed: {msg}" finally: @@ -302,7 +347,4 @@ def test_stream_weights_via_ipc_zmq_impl( proc.kill() if os.path.exists(socket_path): - try: - os.unlink(socket_path) - except: - pass + os.unlink(socket_path) From 1da3381f2fa1fb49d9a44f874bde1e74d4f03291 Mon Sep 17 00:00:00 2001 From: Zhiyu Li Date: Tue, 21 Oct 2025 18:44:47 -0700 Subject: [PATCH 19/20] fix lint Signed-off-by: Zhiyu Li --- tests/unit/models/policy/test_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unit/models/policy/test_utils.py b/tests/unit/models/policy/test_utils.py index 8145cfcc25..2e56025d73 100644 --- a/tests/unit/models/policy/test_utils.py +++ b/tests/unit/models/policy/test_utils.py @@ -21,8 +21,8 @@ import pytest import torch - import zmq + from nemo_rl.models.policy.utils import ( IPCProtocol, calculate_aligned_size, @@ -212,7 +212,7 @@ def client_process( # Validate tensor assert tensor.shape == expected.shape, f"Shape mismatch for {key}" assert tensor.dtype == expected.dtype, f"Dtype mismatch for {key}" - assert torch.allclose(tensor + 1, expected, rtol=1e-7, atol=1e-7), ( + assert torch.allclose(tensor, expected, rtol=1e-7, atol=1e-7), ( f"Values mismatch for {key}" ) From 62acf7a7a49256d993a882674400c0633ebcd8ea Mon Sep 17 00:00:00 2001 From: Zhiyu Li Date: Tue, 21 Oct 2025 23:02:03 -0700 Subject: [PATCH 20/20] longer timeout for tests Signed-off-by: Zhiyu Li --- tests/unit/models/policy/test_utils.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/tests/unit/models/policy/test_utils.py b/tests/unit/models/policy/test_utils.py index 2e56025d73..0b90ab0fbf 100644 --- a/tests/unit/models/policy/test_utils.py +++ b/tests/unit/models/policy/test_utils.py @@ -134,9 +134,7 @@ def server_process( context = zmq.Context() socket = context.socket(zmq.PAIR) socket.setsockopt(zmq.LINGER, 0) # Close immediately on error - socket.setsockopt( - zmq.RCVTIMEO, 5000 - ) # 5 second timeout on recv to detect client disconnection + socket.setsockopt(zmq.RCVTIMEO, 10000) # 10 second timeout socket.bind(zmq_addr) ready_queue.put(("ready", None)) @@ -164,7 +162,6 @@ def server_process( def client_process( zmq_addr: str, known_tensors_data: list[tuple[str, tuple, torch.dtype, torch.Tensor]], - buffer_size_bytes: int, result_queue: multiprocessing.Queue, ) -> None: """Client process that receives and validates tensors via IPC ZMQ.""" @@ -181,9 +178,8 @@ def client_process( context = zmq.Context() socket = context.socket(zmq.PAIR) - socket.setsockopt( - zmq.LINGER, 0 - ) # Close immediately on error, don't wait for pending sends + socket.setsockopt(zmq.LINGER, 0) # Close immediately on error + socket.setsockopt(zmq.RCVTIMEO, 10000) # 10 second timeout socket.connect(zmq_addr) # Receive and validate loop @@ -258,9 +254,8 @@ def check_process_error( class TestStreamWeightsViaIPC: """Test suite for IPC weight streaming functionality.""" - TIMEOUT = 10 + TIMEOUT = 30 # 30 second timeout for additional overhead when running with coverage - @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.parametrize( "test_case,tensor_specs,buffer_size_bytes,test_description", [ @@ -322,7 +317,7 @@ def test_stream_weights_via_ipc_zmq_impl( client_proc = mp_context.Process( target=client_process, - args=(zmq_addr, known_tensors_data, buffer_size_bytes, result_queue), + args=(zmq_addr, known_tensors_data, result_queue), ) client_proc.start()