14 changes: 8 additions & 6 deletions nemo_rl/algorithms/grpo.py
@@ -347,6 +347,10 @@ def setup(
     # wait for all futures to complete
     ray.get(futures_train + futures_inference)

+    # prepare refit info
+    state_dict_info = policy.prepare_refit_info()
+    policy_generation.prepare_refit_info(state_dict_info)
+
     loss_fn = ClippedPGLossFn(loss_config)

     print("\n" + "=" * 60)
@@ -422,17 +426,15 @@ def refit_policy_generation(
         # do update
         for keys in grouped_param_keys:
             ipc_handles = policy.get_weights_ipc_handles(keys)
-            update_success = policy_generation.update_weights(ipc_handles)
+            update_success = policy_generation.update_weights_from_ipc_handles(
+                ipc_handles
+            )
             if not update_success:
                 break
     else:
-        # prepare info for update weights
-        state_dict_info = policy.prepare_info_for_collective()
         # update weights through nccl
         futures_train = policy.broadcast_weights_for_collective()
-        futures_inference = policy_generation.update_weights_from_collective(
-            state_dict_info
-        )
+        futures_inference = policy_generation.update_weights_from_collective()
         # wait for all futures to complete
         ray.get(futures_train)
         results = ray.get(futures_inference)
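Taken together, the grpo.py changes move the weight-metadata exchange into `setup()` and simplify both refit paths. The sketch below is a condensed restatement of the diff for orientation, not code from the PR; `policy`, `policy_generation`, `grouped_param_keys`, and the `colocated_inference` flag are assumed to be defined as in the surrounding functions.

```python
import ray


def setup_refit(policy, policy_generation):
    # setup(): exchange weight metadata once, before any refit happens.
    state_dict_info = policy.prepare_refit_info()
    policy_generation.prepare_refit_info(state_dict_info)


def refit(policy, policy_generation, grouped_param_keys, colocated_inference: bool):
    if colocated_inference:
        # CUDA IPC path, per group of parameter keys.
        for keys in grouped_param_keys:
            ipc_handles = policy.get_weights_ipc_handles(keys)
            if not policy_generation.update_weights_from_ipc_handles(ipc_handles):
                break
    else:
        # NCCL path: the workers already hold the cached metadata,
        # so no `info` argument is passed per call anymore.
        futures_train = policy.broadcast_weights_for_collective()
        futures_inference = policy_generation.update_weights_from_collective()
        ray.get(futures_train)
        ray.get(futures_inference)
```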
10 changes: 6 additions & 4 deletions nemo_rl/models/generation/interfaces.py
@@ -228,12 +228,14 @@ def prepare_for_generation(self, *args: Any, **kwargs: Any) -> bool:
     def finish_generation(self, *args: Any, **kwargs: Any) -> bool:
         pass

-    def update_weights(self, ipc_handles: dict[str, Any]) -> bool:
+    def prepare_refit_info(self, state_dict_info: dict[str, Any]) -> None:
+        """Prepare the info for refit."""
+        raise NotImplementedError
+
+    def update_weights_from_ipc_handles(self, ipc_handles: dict[str, Any]) -> bool:
         """Update the model weights from the given IPC handles."""
         raise NotImplementedError

-    def update_weights_from_collective(
-        self, info: dict[str, Any]
-    ) -> list[ray.ObjectRef]:
+    def update_weights_from_collective(self) -> list[ray.ObjectRef]:
         """Update the model weights from collective communication."""
         raise NotImplementedError
44 changes: 34 additions & 10 deletions nemo_rl/models/generation/vllm.py
@@ -1021,6 +1021,14 @@ async def report_device_id_async(self) -> list[str]:

         return cast(list[str], list_of_worker_results)

+    def prepare_refit_info(self, state_dict_info: dict[str, Any]) -> None:
+        """Prepare the info for refit."""
+        self.llm.collective_rpc("prepare_refit_info", args=(state_dict_info,))
+
+    async def prepare_refit_info_async(self, state_dict_info: dict[str, Any]) -> None:
+        """Async version of prepare_refit_info."""
+        await self.llm.collective_rpc("prepare_refit_info", args=(state_dict_info,))
+
     def update_weights_from_ipc_handles(self, ipc_handles: dict[str, Any]) -> bool:
         """Update weights from IPC handles by delegating to the vLLM Worker implementation.

@@ -1132,7 +1140,7 @@ async def update_weights_from_ipc_handles_async(
             traceback.print_exc()
             return False

-    def update_weights_from_collective(self, info: dict[str, Any]) -> bool:
+    def update_weights_from_collective(self) -> bool:
         """Update the model weights from collective communication."""
         try:
             assert self.llm is not None, (
@@ -1145,7 +1153,7 @@ def update_weights_from_collective(self, info: dict[str, Any]) -> bool:
             )

             result_or_coro = self.llm.collective_rpc(
-                "update_weights_from_collective", args=(info,)
+                "update_weights_from_collective", args=tuple()
             )
             worker_result = result_or_coro[0]

@@ -1162,7 +1170,7 @@ def update_weights_from_collective(self, info: dict[str, Any]) -> bool:
             traceback.print_exc()
             return False

-    async def update_weights_from_collective_async(self, info: dict[str, Any]) -> bool:
+    async def update_weights_from_collective_async(self) -> bool:
         """Async version of update_weights_from_collective."""
         try:
             assert self.llm is not None, (
@@ -1175,7 +1183,7 @@ async def update_weights_from_collective_async(self, info: dict[str, Any]) -> bool:
             )

             result_or_coro = await self.llm.collective_rpc(
-                "update_weights_from_collective", args=(info,)
+                "update_weights_from_collective", args=tuple()
             )

             if asyncio.iscoroutine(result_or_coro):
@@ -1908,7 +1916,26 @@ def shutdown(self) -> bool:
             print(f"Error during policy shutdown: {e}")
             return False

-    def update_weights(self, ipc_handles: dict[str, Any]) -> bool:
+    def prepare_refit_info(self, state_dict_info: dict[str, Any]) -> None:
+        """Prepare the info for refit."""
+        # Choose the appropriate method based on async_engine setting
+        method_name = (
+            "prepare_refit_info_async"
+            if self.cfg["vllm_cfg"]["async_engine"]
+            else "prepare_refit_info"
+        )
+
+        # Use run_all_workers_single_data to send data to all workers
+        futures = self.worker_group.run_all_workers_single_data(
+            method_name,
+            state_dict_info=state_dict_info,
+            run_rank_0_only_axes=["tensor_parallel", "pipeline_parallel"],
+        )
+
+        # Wait for all futures to complete
+        ray.get(futures)
+
+    def update_weights_from_ipc_handles(self, ipc_handles: dict[str, Any]) -> bool:
         """Update weights of the policy using IPC handles, considering tensor parallelism.

         For tp > 1, only the leader in each tensor parallel tied worker group will update weights.
@@ -1952,9 +1979,7 @@ def update_weights(self, ipc_handles: dict[str, Any]) -> bool:
             print(f"Error during update weights: {e}")
             return False

-    def update_weights_from_collective(
-        self, info: dict[str, Any]
-    ) -> list[ray.ObjectRef]:
+    def update_weights_from_collective(self) -> list[ray.ObjectRef]:
         """Update weights of the policy using collective communication."""
         if not self.worker_group or not self.worker_group.workers:
             raise RuntimeError("Worker group is not initialized")
@@ -1966,10 +1991,9 @@ def update_weights_from_collective(
             else "update_weights_from_collective"
         )

-        # Use run_all_workers_single_data to send data to all workers
+        # Use run_all_workers_single_data for methods that don't need data
        futures = self.worker_group.run_all_workers_single_data(
             method_name,
-            info=info,
             run_rank_0_only_axes=["tensor_parallel", "pipeline_parallel"],
         )

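On the worker-group side, `prepare_refit_info` and `update_weights_from_collective` now follow the same dispatch pattern: pick the sync or async worker method based on `async_engine`, fan the call out with `run_all_workers_single_data`, and (for the blocking variant) wait on the futures. A minimal sketch of that shared pattern, using a hypothetical `dispatch_to_vllm_workers` helper that is not part of the PR:

```python
import ray


def dispatch_to_vllm_workers(worker_group, base_name: str, async_engine: bool, **data):
    """Hypothetical helper showing the shared fan-out pattern; the PR inlines
    this logic separately in prepare_refit_info and update_weights_from_collective."""
    # Async engines expose "<name>_async" worker methods; pick accordingly.
    method_name = f"{base_name}_async" if async_engine else base_name
    return worker_group.run_all_workers_single_data(
        method_name,
        run_rank_0_only_axes=["tensor_parallel", "pipeline_parallel"],
        **data,  # e.g. state_dict_info=..., or nothing for the collective update
    )


# prepare_refit_info blocks on completion; update_weights_from_collective returns
# the futures so the caller can overlap them with the trainer-side broadcast:
#   ray.get(dispatch_to_vllm_workers(wg, "prepare_refit_info", async_eng, state_dict_info=info))
#   futures = dispatch_to_vllm_workers(wg, "update_weights_from_collective", async_eng)
```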
43 changes: 39 additions & 4 deletions nemo_rl/models/generation/vllm_backend.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
-from typing import Any
+from typing import Any, Optional

 import torch

@@ -51,6 +51,21 @@ def report_device_id(self) -> str:

         return get_device_uuid(self.device.index)

+    def prepare_refit_info(
+        self, state_dict_info: Optional[dict[str, Any]] = None
+    ) -> None:
+        """Prepare the info for refit.
+
+        DtensorPolicyWorker:
+            colocated inference: state_dict_info is None
+            non-colocated inference: state_dict_info is a dict of {tensor_name: (shape, dtype)}
+
+        MegatronPolicyWorker:
+            colocated inference: state_dict_info is a dict of {tensor_name: (shape, dtype, numel)}
+            non-colocated inference: not implemented yet
+        """
+        self.state_dict_info = state_dict_info
+
     def update_weights_from_global_ipc_handles(self, global_device_ipc_handles):
         """Update weights from global IPC handles.

@@ -84,6 +99,11 @@ def update_weights_from_local_ipc_handles(self, local_device_ipc_handles):
             weights = []

             if is_tensor_packed:
+                assert self.state_dict_info is not None, (
+                    "state_dict_info is not prepared. "
+                    "Please call prepare_refit_info when initializing the worker."
+                )
+
                 # Extract packed tensor from IPC handle
                 dtype_to_packed_tensor = {}
                 for dtype, tensor_handle in all_handles:
@@ -95,7 +115,17 @@ def update_weights_from_local_ipc_handles(self, local_device_ipc_handles):

                 # Unpack tensor to weights. Here we only return a view of the tensor to avoid
                 # using extra memory.
-                for key, (shape, dtype, offset, size) in tensor_metadata.items():
+                for key, metadata in tensor_metadata.items():
+                    # dtype for the 1st and 2nd steps may be different (e.g. e_score_correction_bias)
+                    if isinstance(metadata, tuple):
+                        # use dtype of current step
+                        offset, dtype = metadata
+                        shape, _, size = self.state_dict_info[key]
+                        # update record
+                        self.state_dict_info[key] = (shape, dtype, size)
+                    else:
+                        offset = metadata
+                        shape, dtype, size = self.state_dict_info[key]
                     tensor = dtype_to_packed_tensor[dtype][offset : offset + size].view(
                         *shape
                     )
@@ -118,10 +148,15 @@ def update_weights_from_local_ipc_handles(self, local_device_ipc_handles):
             )
             return False

-    def update_weights_from_collective(self, info: dict[str, Any]) -> bool:
+    def update_weights_from_collective(self) -> bool:
         """Update the model weights from collective communication."""
+        assert self.state_dict_info is not None, (
+            "state_dict_info is not prepared. "
+            "Please call prepare_refit_info when initializing the worker."
+        )
+
         try:
-            for name, (shape, dtype) in info.items():
+            for name, (shape, dtype) in self.state_dict_info.items():
                 weight = torch.empty(shape, dtype=dtype, device="cuda")
                 self.model_update_group.broadcast(weight, src=0)
                 self.model_runner.model.load_weights(weights=[(name, weight)])
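The trickiest part of the backend change is the packed-IPC path: `prepare_refit_info` caches per-tensor `(shape, dtype, numel)` once (the Megatron colocated format), and each refit step then ships only an offset, or `(offset, dtype)` when that step's dtype differs from the cached one. The snippet below is a small, self-contained illustration of that unpacking logic; the tensor names, shapes, and CPU buffers are made up for demonstration, whereas the real worker views into CUDA buffers rebuilt from IPC handles.

```python
import torch

# Cached once by prepare_refit_info (Megatron colocated format: shape, dtype, numel).
state_dict_info = {
    "w1": ((2, 3), torch.bfloat16, 6),
    "bias": ((4,), torch.bfloat16, 4),
}

# Per-step metadata from the IPC handles: a bare offset when the dtype matches
# the cache, or (offset, dtype) when this step ships a different dtype
# (e.g. e_score_correction_bias).
tensor_metadata = {
    "w1": 0,
    "bias": (6, torch.float32),
}

# Stand-ins for the packed buffers rebuilt from the IPC handles.
dtype_to_packed_tensor = {
    torch.bfloat16: torch.zeros(6, dtype=torch.bfloat16),
    torch.float32: torch.zeros(10, dtype=torch.float32),
}

weights = []
for key, metadata in tensor_metadata.items():
    if isinstance(metadata, tuple):
        offset, dtype = metadata                     # dtype of the current step
        shape, _, size = state_dict_info[key]
        state_dict_info[key] = (shape, dtype, size)  # remember the new dtype
    else:
        offset = metadata
        shape, dtype, size = state_dict_info[key]
    # Views into the packed buffer: no copies are made.
    weights.append((key, dtype_to_packed_tensor[dtype][offset : offset + size].view(*shape)))

assert weights[0][1].shape == (2, 3) and weights[1][1].dtype == torch.float32
```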