Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 18 additions & 15 deletions nemo_rl/algorithms/grpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -732,6 +732,7 @@ def refit_policy_generation(
_refit_buffer_size_gb: The size of the buffer to use for refitting.
If it is None, the buffer size will be computed by the remaining memory.
This parameter is primarily used for testing.
timer: Optional Timer used to time the prepare/transfer/update phase
"""
if colocated_inference:
policy.offload_before_refit()
Expand All @@ -748,22 +749,24 @@ def refit_policy_generation(
update_success = False
if colocated_inference:
# get model param keys, which is grouped by size
grouped_param_keys = policy.prepare_weights_for_ipc(
_refit_buffer_size_gb=_refit_buffer_size_gb
)
total_num_keys = sum(len(k) for k in grouped_param_keys)
print(
f"[Refit] Split {total_num_keys} keys into {len(grouped_param_keys)} groups",
flush=True,
)
# do update
for keys in grouped_param_keys:
ipc_handles = policy.get_weights_ipc_handles(keys)
update_success = policy_generation.update_weights_from_ipc_handles(
ipc_handles
if _refit_buffer_size_gb is not None:
buffer_size_bytes = _refit_buffer_size_gb * (1024**3)
else:
# Empirically sets ratio as 30% to maximize efficiency.
# The remaining 70% is a necessary buffer reserved for the parameter all-gathering across the expert-parallelism dimension.
memory_ratio = os.getenv("NRL_REFIT_BUFFER_MEMORY_RATIO", "0.3")
buffer_size_bytes = int(
policy.get_free_memory_bytes() * float(memory_ratio)
)
if not update_success:
break

futures_train = policy.stream_weights_via_ipc_zmq(
buffer_size_bytes=buffer_size_bytes
)
futures_inference = policy_generation.update_weights_via_ipc_zmq()
# wait for all futures to complete
ray.get(futures_train)
results = ray.get(futures_inference)
update_success = all(result for result in results if result is not None)
else:
# update weights through nccl
futures_train = policy.broadcast_weights_for_collective()
Expand Down
2 changes: 1 addition & 1 deletion nemo_rl/models/generation/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def prepare_refit_info(self, state_dict_info: dict[str, Any]) -> None:
"""Prepare the info for refit."""
raise NotImplementedError

def update_weights_from_ipc_handles(self, ipc_handles: dict[str, Any]) -> bool:
def update_weights_via_ipc_zmq(self) -> list[ray.ObjectRef]:
"""Update the model weights from the given IPC handles."""
raise NotImplementedError

Expand Down
179 changes: 87 additions & 92 deletions nemo_rl/models/generation/vllm/vllm_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
from typing import Any, Optional
import gc
from typing import Any

import torch
from torch.multiprocessing.reductions import rebuild_cuda_tensor
import zmq

from nemo_rl.models.policy.utils import (
IPCProtocol,
calculate_aligned_size,
rebuild_cuda_tensor_from_ipc,
)
from nemo_rl.utils.nsys import wrap_with_nvtx_name
from nemo_rl.utils.packed_tensor import packed_broadcast_consumer

Expand Down Expand Up @@ -56,124 +61,107 @@ def init_collective(
)

def report_device_id(self) -> str:
"""Retrieve the UUID of the current CUDA device."""
from nemo_rl.utils.nvml import get_device_uuid

return get_device_uuid(self.device.index)

def prepare_refit_info(
self, state_dict_info: Optional[dict[str, Any]] = None
) -> None:
"""Prepare the info for refit.

DtensorPolicyWorker:
colocated inference: state_dict_info is None
non-colocated inference: state_dict_info is a dict of {tensor_name: (shape, dtype)}
def get_zmq_address(self):
"""Get the ZMQ address for the current device."""
return f"ipc:///tmp/{self.report_device_id()}.sock"

MegatronPolicyWorker:
colocated inference: state_dict_info is a dict of {tensor_name: (shape, dtype, numel)}
non-colocated inference: state_dict_info is a dict of {tensor_name: (shape, dtype)}
"""
self.state_dict_info = state_dict_info # pyrefly: ignore[implicitly-defined-attribute] This class does not define __init__ so assignments like this should be ignored
def maybe_init_zmq(self):
"""Initialize the ZMQ socket if it doesn't exist."""
if not hasattr(self, "zmq_socket"):
self.zmq_context = zmq.Context() # pyrefly: ignore[implicitly-defined-attribute] This class does not define __init__ so assignments like this should be ignored
self.zmq_socket = self.zmq_context.socket( # pyrefly: ignore[implicitly-defined-attribute] This class does not define __init__ so assignments like this should be ignored
zmq.REP
)
self.zmq_socket.setsockopt(zmq.SNDTIMEO, 30000) # set timeout to 30 seconds
self.zmq_socket.setsockopt(zmq.RCVTIMEO, 30000) # set timeout to 30 seconds
self.zmq_socket.setsockopt(zmq.LINGER, 0)
self.zmq_socket.connect(self.get_zmq_address())

@wrap_with_nvtx_name(
"vllm_internal_worker_extension/update_weights_from_global_ipc_handles"
)
def update_weights_from_global_ipc_handles(self, global_device_ipc_handles):
"""Update weights from global IPC handles.
def prepare_refit_info(self, state_dict_info: dict[str, Any]) -> None:
"""Prepare state dict metadata for weight refitting and IPC streaming.

Args:
global_device_ipc_handles (dict): Dictionary mapping device UUIDs to parameter IPC handles.

Returns:
bool: True if weights were successfully updated.
state_dict_info (dict): A dictionary containing the info for refit.
e.g. {tensor_name: (shape, dtype)}
"""
device_uuid = self.report_device_id()
local_device_ipc_handles = global_device_ipc_handles[device_uuid]
return self.update_weights_from_local_ipc_handles(local_device_ipc_handles)

@wrap_with_nvtx_name(
"vllm_internal_worker_extension/update_weights_from_local_ipc_handles"
)
def update_weights_from_local_ipc_handles(self, local_device_ipc_handles):
"""Update weights from local IPC handles.
self.state_dict_info = state_dict_info # pyrefly: ignore[implicitly-defined-attribute] This class does not define __init__ so assignments like this should be ignored

Args:
local_device_ipc_handles (dict): parameter IPC handles for local device.
@wrap_with_nvtx_name("vllm_internal_worker_extension/update_weights_via_ipc_zmq")
def update_weights_via_ipc_zmq(self) -> bool:
"""Receive and update model weights via ZMQ IPC socket.

Returns:
bool: True if weights were successfully updated.
"""
try:
is_tensor_packed = local_device_ipc_handles[0]
if is_tensor_packed:
_, all_handles, list_keys = local_device_ipc_handles
else:
_, name_and_handle_list = local_device_ipc_handles
buffer = None
weights = None

device_id = self.device.index
weights = []
try:
self.maybe_init_zmq()
while True:
# Blocking receive with timeout (this is the main operation)
payload = self.zmq_socket.recv_pyobj()

if is_tensor_packed:
assert self.state_dict_info is not None, (
"state_dict_info is not prepared. "
"Please call prepare_refit_info when initializing the worker."
)
if payload == IPCProtocol.COMPLETE:
# means the update is done
self.zmq_socket.send(IPCProtocol.ACK.value.encode())
break

# Extract packed tensor from IPC handle
dtype_to_packed_tensor = {}
for dtype, tensor_handle in all_handles:
func = rebuild_cuda_tensor
args = tensor_handle[0]
list_args = list(args)
list_args[6] = device_id
tensor = func(*list_args)
dtype_to_packed_tensor[dtype] = tensor
ipc_handle, list_keys, used_bytes = payload
buffer = rebuild_cuda_tensor_from_ipc(ipc_handle, self.device.index)

weights = []
dtype_to_offset = defaultdict(lambda: 0)
offset = 0
for key in list_keys:
shape, dtype, size = self.state_dict_info[key]
shape, dtype = self.state_dict_info[key] # pyrefly
if isinstance(shape, list):
shape = torch.Size(shape)
size_in_bytes = dtype.itemsize * shape.numel()
weights.append(
(
key,
dtype_to_packed_tensor[dtype][
dtype_to_offset[dtype] : dtype_to_offset[dtype] + size
].view(*shape),
buffer[offset : offset + size_in_bytes]
.view(dtype=dtype)
.view(shape),
)
)
dtype_to_offset[dtype] += size

expected_sizes = {
dtype: tensor.numel()
for dtype, tensor in dtype_to_packed_tensor.items()
}
assert dtype_to_offset == expected_sizes, (
f"Packed tensor size mismatch: expected sizes from keys list {expected_sizes} != actual packed tensor sizes {dtype_to_offset}. "
f"This indicates the keys list order doesn't match the order used when packing tensors."
aligned_size = calculate_aligned_size(size_in_bytes)
offset += aligned_size
assert offset == used_bytes, (
"Offset is not equal to used bytes, usually indicate inaccurate info like keys or cached dtype in state_dict_info"
)
else:
# Process each handle to get the tensor
for name, handle in name_and_handle_list:
func = rebuild_cuda_tensor
args = handle[0]
list_args = list(args)
list_args[6] = device_id
tensor = func(*list_args)
weights.append((name, tensor))

# Load weights into the model
from nemo_rl.models.generation import fp8

if fp8.is_fp8_model(self.model_runner.vllm_config):
# the fp8 load_weights additionally casts bf16 weights into fp8
fp8.load_weights(weights, self.model_runner)
else:
self.model_runner.model.load_weights(weights=weights)

# Load weights into the model
from nemo_rl.models.generation import fp8

if fp8.is_fp8_model(self.model_runner.vllm_config):
# the fp8 load_weights additionally casts bf16 weights into fp8
fp8.load_weights(weights, self.model_runner)
else:
self.model_runner.model.load_weights(weights=weights)

torch.cuda.current_stream().synchronize()

# CRITICAL: Delete views before ACK to prevent corruption.
# 'weights' contains views into IPC shared memory. Even though load_weights()
# copied the data, Python may not garbage collect these view objects immediately.
# If sender reuses the buffer before GC runs, old views would read corrupted data.
# Explicit del ensures immediate cleanup before sending ACK.
del weights, buffer
weights = None
buffer = None
self.zmq_socket.send(IPCProtocol.ACK.value.encode())

gc.collect()
torch.cuda.empty_cache()
return True
except Exception as e:
print(
f"Error in VllmInternalWorkerExtension.update_weights_from_ipc_handles: {e}"
f"Error in VllmInternalWorkerExtension.update_weights_via_ipc_zmq: {e}"
)
return False

Expand Down Expand Up @@ -222,6 +210,13 @@ def _load_model_weights(weights, model_runner):

return True

def cleanup(self) -> None:
"""Shutdown and cleanup resources."""
# Close ZMQ socket and context if they exist
if hasattr(self, "zmq_socket"):
self.zmq_socket.close()
self.zmq_context.term()

def start_gpu_profiling(self) -> None:
"""Start GPU profiling."""
torch.cuda.profiler.start()
Expand Down
47 changes: 12 additions & 35 deletions nemo_rl/models/generation/vllm/vllm_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -763,49 +763,26 @@ def prepare_refit_info(self, state_dict_info: dict[str, Any]) -> None:
# Wait for all futures to complete
ray.get(futures)

def update_weights_from_ipc_handles(self, ipc_handles: dict[str, Any]) -> bool:
"""Update weights of the policy using IPC handles, considering tensor parallelism.

For tp > 1, only the leader in each tensor parallel tied worker group will update weights.

Args:
ipc_handles (dict): Dictionary mapping device UUIDs (str) to parameter IPC handles.

Returns:
bool: True if weights were successfully updated, False otherwise.
"""
def update_weights_via_ipc_zmq(self) -> list[ray.ObjectRef]:
"""Update weights of the policy using IPC handles via ZMQ socket."""
if not self.worker_group or not self.worker_group.workers:
return False
raise RuntimeError("Worker group is not initialized")

# Choose the appropriate method based on async_engine setting
method_name = (
"update_weights_from_ipc_handles_async"
"update_weights_via_ipc_zmq_async"
if self.cfg["vllm_cfg"]["async_engine"]
else "update_weights_from_ipc_handles"
else "update_weights_via_ipc_zmq"
)

# Only send the ipc handles required by the current worker
ipc_handles_list = []
for worker_device_uuids in self.device_uuids:
worker_ipc_handles = {
device_uuid: ipc_handles[device_uuid]
for device_uuid in worker_device_uuids
}
ipc_handles_list.append(worker_ipc_handles)
# Use run_all_workers_single_data since no data needs to be passed
futures = self.worker_group.run_all_workers_single_data(
method_name,
run_rank_0_only_axes=["tensor_parallel", "pipeline_parallel"],
)

try:
# Directly pass ipc_handles to the method
futures = self.worker_group.run_all_workers_multiple_data(
method_name,
ipc_handles=ipc_handles_list,
run_rank_0_only_axes=["tensor_parallel", "pipeline_parallel"],
)
# Wait for all futures to complete
results = ray.get(futures)
return all(result for result in results if result is not None)
except Exception as e:
print(f"Error during update weights: {e}")
return False
# this function should co-work with lm_policy, so we should wait for all futures to complete outside
return futures

def update_weights_from_collective(self) -> list[ray.ObjectRef]:
"""Update weights of the policy using collective communication."""
Expand Down
Loading
Loading