Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 11 additions & 6 deletions nemo_rl/models/generation/vllm_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,18 +58,23 @@ def update_weights_from_ipc_handles(self, ipc_handles):
try:
# Get handles for this device
device_uuid = self.report_device_id()
handles = ipc_handles[device_uuid]
deserialized = ipc_handles[device_uuid]
device_id = self.device.index
weights = []

# Process each handle to get the tensor
for name, handle in handles:
func, args = handle
all_handles, key_to_type_and_offset_and_size_in_big_tensor = deserialized
type_to_packed_big_tensor_size = {}
for k, tensor_handle in all_handles:
func, args = tensor_handle
list_args = list(args)
# Update device ID to match the current device
list_args[6] = device_id
tensor = func(*list_args)
weights.append((name, tensor))
type_to_packed_big_tensor_size[k] = tensor

for key, shape, type, offset, size in key_to_type_and_offset_and_size_in_big_tensor:
assert offset+size <= type_to_packed_big_tensor_size[type].numel()
tensor = type_to_packed_big_tensor_size[type][offset:offset+size].clone().reshape(shape)
weights.append((key, tensor))

# Load weights into the model
self.model_runner.model.load_weights(weights=weights)
Expand Down
42 changes: 35 additions & 7 deletions nemo_rl/models/policy/megatron_policy_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -1367,18 +1367,46 @@ def get_weights_ipc_handles(self, *, keys: list[str]) -> dict[str, Any]:
from torch.multiprocessing.reductions import reduce_tensor

# Create IPC handles for each parameter
all_handles = []

# pack tensors in gathered_hf_params to a big tensor
type_to_packed_big_tensor_size = defaultdict(lambda : 0)
key_to_type_and_offset_and_size_in_big_tensor = []

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would suggest running this code through Cursor to simplify the code and clean up the variable names and typos.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this is still a draft; I will do some cleanup.

for key, tensor in gathered_hf_params.items():
key_to_type_and_offset_and_size_in_big_tensor.append(
(
key,
tensor.shape,
tensor.dtype,
type_to_packed_big_tensor_size[tensor.dtype],
tensor.numel()
)
)
type_to_packed_big_tensor_size[tensor.dtype] += tensor.numel()

type_to_packed_big_tensor_size = {
k: torch.empty(v, device=tensor.device, dtype=k, requires_grad=False)
for k, v in type_to_packed_big_tensor_size.items()
}
for i, (key, tensor) in enumerate(gathered_hf_params.items()):
k, shape, dtype, offset, size = key_to_type_and_offset_and_size_in_big_tensor[i]
assert k == key
type_to_packed_big_tensor_size[dtype][offset:offset+size] = tensor.detach().view(-1)

all_handles = []
for dtype, tensor in type_to_packed_big_tensor_size.items():
handle = reduce_tensor(tensor.detach())
all_handles.append((key, handle))
all_handles.append((dtype, handle))

# Store references to avoid premature garbage collection
self._held_gather_buffer = gathered_hf_params
shapes = {}
for key, tensor in gathered_hf_params.items():
shapes[key] = tensor.shape

return {device_uuid: all_handles}
self._held_gather_buffer = type_to_packed_big_tensor_size
# shapes = {}
# for key, tensor in gathered_hf_params.items():
# shapes[key] = tensor.shape

serielized = (all_handles, key_to_type_and_offset_and_size_in_big_tensor)

return {device_uuid: serielized}

def prepare_for_lp_inference(self):
self.model = self.move_model(self.model, "cuda", move_grads=False)
Expand Down
Loading