2 changes: 1 addition & 1 deletion 3rdparty/NeMo-workspace/NeMo
Submodule NeMo updated from 33259f to 8ddf43
14 changes: 8 additions & 6 deletions nemo_rl/algorithms/grpo.py
@@ -347,6 +347,10 @@ def setup(
# wait for all futures to complete
ray.get(futures_train + futures_inference)

# prepare refit info
state_dict_info = policy.prepare_refit_info()
policy_generation.prepare_refit_info(state_dict_info)

loss_fn = ClippedPGLossFn(loss_config)

print("\n" + "=" * 60)
@@ -422,17 +426,15 @@ def refit_policy_generation(
# do update
for keys in grouped_param_keys:
ipc_handles = policy.get_weights_ipc_handles(keys)
update_success = policy_generation.update_weights(ipc_handles)
update_success = policy_generation.update_weights_from_ipc_handles(
ipc_handles
)
if not update_success:
break
else:
# prepare info for update weights
state_dict_info = policy.prepare_info_for_collective()
# update weights through nccl
futures_train = policy.broadcast_weights_for_collective()
futures_inference = policy_generation.update_weights_from_collective(
state_dict_info
)
futures_inference = policy_generation.update_weights_from_collective()
# wait for all futures to complete
ray.get(futures_train)
results = ray.get(futures_inference)
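Note: the two branches above correspond to colocated and non-colocated inference. A minimal sketch of the resulting control flow, assuming policy and policy_generation objects that expose the methods changed in this PR (error handling and the colocated-inference check simplified):

import ray

def refit_sketch(policy, policy_generation, grouped_param_keys, colocated: bool):
    if colocated:
        # Colocated: stream weights through CUDA IPC handles, one key group at a time.
        for keys in grouped_param_keys:
            ipc_handles = policy.get_weights_ipc_handles(keys)
            if not policy_generation.update_weights_from_ipc_handles(ipc_handles):
                raise RuntimeError("IPC-handle weight update failed")
    else:
        # Non-colocated: shapes/dtypes were already registered once at setup via
        # prepare_refit_info, so the NCCL broadcast itself no longer carries metadata.
        futures_train = policy.broadcast_weights_for_collective()
        futures_inference = policy_generation.update_weights_from_collective()
        ray.get(futures_train)
        if not all(ray.get(futures_inference)):
            raise RuntimeError("collective weight update failed")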
10 changes: 6 additions & 4 deletions nemo_rl/models/generation/interfaces.py
@@ -228,12 +228,14 @@ def prepare_for_generation(self, *args: Any, **kwargs: Any) -> bool:
def finish_generation(self, *args: Any, **kwargs: Any) -> bool:
pass

def update_weights(self, ipc_handles: dict[str, Any]) -> bool:
def prepare_refit_info(self, state_dict_info: dict[str, Any]) -> None:
"""Prepare the info for refit."""
raise NotImplementedError

def update_weights_from_ipc_handles(self, ipc_handles: dict[str, Any]) -> bool:
"""Update the model weights from the given IPC handles."""
raise NotImplementedError

def update_weights_from_collective(
self, info: dict[str, Any]
) -> list[ray.ObjectRef]:
def update_weights_from_collective(self) -> list[ray.ObjectRef]:
"""Update the model weights from collective communication."""
raise NotImplementedError
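Note: the renamed interface separates the one-time metadata exchange from the per-refit weight transfer. A hypothetical backend implementing the updated contract might look like the sketch below; everything except the three interface method names is illustrative and not part of this PR.

from typing import Any

import ray


class ExampleGenerationBackend:  # hypothetical backend, for illustration only
    def prepare_refit_info(self, state_dict_info: dict[str, Any]) -> None:
        # Cache per-tensor metadata (e.g. {name: (shape, dtype)}) once at setup.
        self._state_dict_info = state_dict_info

    def update_weights_from_ipc_handles(self, ipc_handles: dict[str, Any]) -> bool:
        # A real implementation would rebuild tensors from the CUDA IPC handles
        # and load them into the model here.
        return True  # return False on failure so the caller can stop early

    def update_weights_from_collective(self) -> list[ray.ObjectRef]:
        # A real implementation would launch per-worker NCCL receives using the
        # cached metadata and return the futures for the caller to await.
        return []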
44 changes: 34 additions & 10 deletions nemo_rl/models/generation/vllm.py
@@ -1021,6 +1021,14 @@ async def report_device_id_async(self) -> list[str]:

return cast(list[str], list_of_worker_results)

def prepare_refit_info(self, state_dict_info: dict[str, Any]) -> None:
"""Prepare the info for refit."""
self.llm.collective_rpc("prepare_refit_info", args=(state_dict_info,))

async def prepare_refit_info_async(self, state_dict_info: dict[str, Any]) -> None:
"""Async version of prepare_refit_info."""
await self.llm.collective_rpc("prepare_refit_info", args=(state_dict_info,))

def update_weights_from_ipc_handles(self, ipc_handles: dict[str, Any]) -> bool:
"""Update weights from IPC handles by delegating to the vLLM Worker implementation.

@@ -1132,7 +1140,7 @@ async def update_weights_from_ipc_handles_async(
traceback.print_exc()
return False

def update_weights_from_collective(self, info: dict[str, Any]) -> bool:
def update_weights_from_collective(self) -> bool:
"""Update the model weights from collective communication."""
try:
assert self.llm is not None, (
@@ -1145,7 +1153,7 @@ def update_weights_from_collective(self, info: dict[str, Any]) -> bool:
)

result_or_coro = self.llm.collective_rpc(
"update_weights_from_collective", args=(info,)
"update_weights_from_collective", args=tuple()
)
worker_result = result_or_coro[0]

@@ -1162,7 +1170,7 @@ async def update_weights_from_collective_async(self, info: dict[str, Any]) -> bool:
traceback.print_exc()
return False

async def update_weights_from_collective_async(self, info: dict[str, Any]) -> bool:
async def update_weights_from_collective_async(self) -> bool:
"""Async version of update_weights_from_collective."""
try:
assert self.llm is not None, (
@@ -1175,7 +1183,7 @@ async def update_weights_from_collective_async(self, info: dict[str, Any]) -> bool:
)

result_or_coro = await self.llm.collective_rpc(
"update_weights_from_collective", args=(info,)
"update_weights_from_collective", args=tuple()
)

if asyncio.iscoroutine(result_or_coro):
@@ -1908,7 +1916,26 @@ def shutdown(self) -> bool:
print(f"Error during policy shutdown: {e}")
return False

def update_weights(self, ipc_handles: dict[str, Any]) -> bool:
def prepare_refit_info(self, state_dict_info: dict[str, Any]) -> None:
"""Prepare the info for refit."""
# Choose the appropriate method based on async_engine setting
method_name = (
"prepare_refit_info_async"
if self.cfg["vllm_cfg"]["async_engine"]
else "prepare_refit_info"
)

# Use run_all_workers_single_data to send data to all workers
futures = self.worker_group.run_all_workers_single_data(
method_name,
state_dict_info=state_dict_info,
run_rank_0_only_axes=["tensor_parallel", "pipeline_parallel"],
)

# Wait for all futures to complete
ray.get(futures)

def update_weights_from_ipc_handles(self, ipc_handles: dict[str, Any]) -> bool:
"""Update weights of the policy using IPC handles, considering tensor parallelism.

For tp > 1, only the leader in each tensor parallel tied worker group will update weights.
@@ -1952,9 +1979,7 @@ def update_weights(self, ipc_handles: dict[str, Any]) -> bool:
print(f"Error during update weights: {e}")
return False

def update_weights_from_collective(
self, info: dict[str, Any]
) -> list[ray.ObjectRef]:
def update_weights_from_collective(self) -> list[ray.ObjectRef]:
"""Update weights of the policy using collective communication."""
if not self.worker_group or not self.worker_group.workers:
raise RuntimeError("Worker group is not initialized")
@@ -1966,10 +1991,9 @@ def update_weights_from_collective(
else "update_weights_from_collective"
)

# Use run_all_workers_single_data to send data to all workers
# Use run_all_workers_single_data for methods that don't need data
futures = self.worker_group.run_all_workers_single_data(
method_name,
info=info,
run_rank_0_only_axes=["tensor_parallel", "pipeline_parallel"],
)

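Note: prepare_refit_info and update_weights_from_collective now follow the same dispatch pattern in VllmGeneration: pick the sync or async worker method based on cfg["vllm_cfg"]["async_engine"], fan it out with run_all_workers_single_data restricted to rank 0 of the tensor- and pipeline-parallel axes, and either wait on the futures or hand them back. A condensed sketch of that shared pattern; the helper name is illustrative and the worker-group API is assumed to behave as in the surrounding diff.

import ray

def _dispatch_refit_call(self, base_name: str, wait: bool = True, **kwargs):
    # Async engines expose "<base_name>_async" worker methods; sync engines use "<base_name>".
    method_name = (
        f"{base_name}_async" if self.cfg["vllm_cfg"]["async_engine"] else base_name
    )
    futures = self.worker_group.run_all_workers_single_data(
        method_name,
        run_rank_0_only_axes=["tensor_parallel", "pipeline_parallel"],
        **kwargs,
    )
    # prepare_refit_info blocks here until every worker has cached the metadata;
    # update_weights_from_collective instead returns the futures to the caller.
    return ray.get(futures) if wait else futures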
45 changes: 39 additions & 6 deletions nemo_rl/models/generation/vllm_backend.py
@@ -12,9 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Any
from typing import Any, Optional

import torch
from torch.multiprocessing.reductions import rebuild_cuda_tensor

try:
import vllm # noqa: F401
@@ -51,6 +52,21 @@ def report_device_id(self) -> str:

return get_device_uuid(self.device.index)

def prepare_refit_info(
self, state_dict_info: Optional[dict[str, Any]] = None
) -> None:
"""Prepare the info for refit.

DtensorPolicyWorker:
colocated inference: state_dict_info is None
non-colocated inference: state_dict_info is a dict of {tensor_name: (shape, dtype)}

MegatronPolicyWorker:
colocated inference: state_dict_info is a dict of {tensor_name: (shape, dtype, numel)}
non-colocated inference: not implemented yet
"""
self.state_dict_info = state_dict_info

def update_weights_from_global_ipc_handles(self, global_device_ipc_handles):
"""Update weights from global IPC handles.

@@ -87,23 +103,35 @@ def update_weights_from_local_ipc_handles(self, local_device_ipc_handles):
# Extract packed tensor from IPC handle
dtype_to_packed_tensor = {}
for dtype, tensor_handle in all_handles:
func, args = tensor_handle
func = rebuild_cuda_tensor
args = tensor_handle[0]
list_args = list(args)
list_args[6] = device_id
tensor = func(*list_args)
dtype_to_packed_tensor[dtype] = tensor

# Unpack tensor to weights. Here we only return a view of the tensor to avoid
# using extra memory.
for key, (shape, dtype, offset, size) in tensor_metadata.items():
for key, metadata in tensor_metadata.items():
# dtype for the 1st and 2nd steps may be different (e.g. e_score_correction_bias)
if isinstance(metadata, tuple):
# use dtype of current step
offset, dtype = metadata
shape, _, size = self.state_dict_info[key]
# update record
self.state_dict_info[key] = (shape, dtype, size)
else:
offset = metadata
shape, dtype, size = self.state_dict_info[key]
tensor = dtype_to_packed_tensor[dtype][offset : offset + size].view(
*shape
)
weights.append((key, tensor))
else:
# Process each handle to get the tensor
for name, handle in name_and_handle_list:
func, args = handle
func = rebuild_cuda_tensor
args = handle[0]
list_args = list(args)
list_args[6] = device_id
tensor = func(*list_args)
@@ -118,10 +146,15 @@ def update_weights_from_local_ipc_handles(self, local_device_ipc_handles):
)
return False

def update_weights_from_collective(self, info: dict[str, Any]) -> bool:
def update_weights_from_collective(self) -> bool:
"""Update the model weights from collective communication."""
assert self.state_dict_info is not None, (
"state_dict_info is not prepared. "
"Please call prepare_refit_info when initializing the worker."
)

try:
for name, (shape, dtype) in info.items():
for name, (shape, dtype) in self.state_dict_info.items():
weight = torch.empty(shape, dtype=dtype, device="cuda")
self.model_update_group.broadcast(weight, src=0)
self.model_runner.model.load_weights(weights=[(name, weight)])
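Note: on the worker side, the cached state_dict_info drives the collective receive loop above. An illustrative example of the metadata shape for the non-colocated DTensor path and of how each entry is consumed; the parameter names and sizes below are made up.

import torch

# Example of what prepare_refit_info might cache for the non-colocated DTensor
# path: {tensor_name: (shape, dtype)}.
state_dict_info = {
    "model.embed_tokens.weight": ((32000, 4096), torch.bfloat16),
    "model.norm.weight": ((4096,), torch.bfloat16),
}

# Each collective refit then mirrors update_weights_from_collective(): allocate an
# empty buffer per entry, receive rank 0's broadcast into it, and hand the named
# tensor to the vLLM model loader.
for name, (shape, dtype) in state_dict_info.items():
    weight = torch.empty(shape, dtype=dtype)  # allocated on "cuda" in the real worker
    # self.model_update_group.broadcast(weight, src=0)               # NCCL receive from rank 0
    # self.model_runner.model.load_weights(weights=[(name, weight)])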