-
Notifications
You must be signed in to change notification settings - Fork 452
refactor: Introduce BasePolicyWorker #1585
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 8 commits
Commits
Show all changes
14 commits
Select commit
Hold shift + click to select a range
d1b92f6
introduce BasePolicyWorker
ashors1 a828cd2
Merge branch 'main' of github.com:NVIDIA-NeMo/RL into ashors/base-pol…
ashors1 547f360
fix imports
ashors1 d8054fb
address comments
ashors1 6ec6706
BasePolicyWorker --> AbstractPolicyWorker
ashors1 da556b8
add copyright header
ashors1 7722704
address comments
ashors1 39c5f36
Merge branch 'main' of github.com:NVIDIA-NeMo/RL into ashors/base-pol…
ashors1 a93db16
refactor dtensor policy worker v1
ashors1 8c975d2
Merge branch 'main' of github.com:NVIDIA-NeMo/RL into ashors/base-pol…
ashors1 678e0ee
small bug fixes
ashors1 81b085b
lint
ashors1 6a2ee7c
lint
ashors1 0b7e18b
fix unit tests
ashors1 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,135 @@ | ||
| # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| import ray | ||
| import torch | ||
| import zmq | ||
| from typing import Any, Optional | ||
|
|
||
| from nemo_rl.distributed.batched_data_dict import BatchedDataDict | ||
| from nemo_rl.models.policy.interfaces import ReferenceLogprobOutputSpec | ||
|
|
||
|
|
||
| class AbstractPolicyWorker: | ||
| """Base class for policy workers with shared functionality.""" | ||
|
|
||
| def init_collective( | ||
| self, ip: str, port: int, world_size: int, *, train_world_size: int | ||
| ) -> None: | ||
| """Initialize the collective communication. | ||
|
|
||
| Args: | ||
| ip: IP address for the process group | ||
| port: Port for the process group | ||
| world_size: Total world size (train_world_size + inference_world_size) | ||
| train_world_size: Number of training workers (used in inference cluster) | ||
| """ | ||
| from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator | ||
| from vllm.distributed.utils import StatelessProcessGroup | ||
|
|
||
| pg = StatelessProcessGroup.create( | ||
| host=ip, port=port, rank=self.rank, world_size=world_size | ||
| ) | ||
| device = torch.cuda.current_device() | ||
| self.model_update_group = PyNcclCommunicator(pg, device=device) | ||
|
|
||
| def is_alive(self) -> bool: | ||
| """Check if the worker is alive.""" | ||
| return True | ||
|
|
||
| def reset_peak_memory_stats(self) -> None: | ||
| """Reset peak memory statistics.""" | ||
| torch.cuda.reset_peak_memory_stats() | ||
|
|
||
| def get_gpu_info(self) -> dict[str, Any]: | ||
| """Return information about the GPU being used by this worker.""" | ||
| from nemo_rl.models.policy.utils import get_gpu_info | ||
| return get_gpu_info(self.model) | ||
|
|
||
| def report_device_id(self) -> str: | ||
| """Report the UUID of the current CUDA device using NVML. | ||
| Returns: | ||
| str: UUID of the device in the format "GPU-xxxxx" | ||
| """ | ||
| from nemo_rl.utils.nvml import get_device_uuid | ||
|
|
||
| # Get current device index from torch | ||
| device_idx = torch.cuda.current_device() | ||
| # Get device UUID using NVML | ||
| return get_device_uuid(device_idx) | ||
|
|
||
| def get_zmq_address(self) -> str: | ||
| """Get the ZMQ address for the current device.""" | ||
| return f"ipc:///tmp/{self.report_device_id()}.sock" | ||
|
|
||
| def maybe_init_zmq(self) -> None: | ||
| """Initialize the ZMQ socket if it doesn't exist.""" | ||
| if not hasattr(self, "zmq_socket"): | ||
| self.zmq_context = zmq.Context() | ||
| self.zmq_socket = self.zmq_context.socket(zmq.REQ) | ||
| self.zmq_socket.setsockopt( | ||
| zmq.SNDTIMEO, 120000 | ||
| ) # set timeout to 120 seconds | ||
| self.zmq_socket.setsockopt( | ||
| zmq.RCVTIMEO, 120000 | ||
| ) # set timeout to 120 seconds | ||
| self.zmq_socket.setsockopt(zmq.LINGER, 0) | ||
| self.zmq_socket.bind(self.get_zmq_address()) | ||
|
|
||
| def get_free_memory_bytes(self) -> int: | ||
| """Get the available free memory.""" | ||
| from nemo_rl.utils.nvml import get_free_memory_bytes | ||
|
|
||
| device_idx = torch.cuda.current_device() | ||
| return get_free_memory_bytes(device_idx) | ||
|
|
||
| def shutdown(self) -> None: | ||
| """Shutdown the policy.""" | ||
| # Clean up extension resources like ZMQ sockets | ||
| if hasattr(self, "zmq_socket"): | ||
| self.zmq_socket.close() | ||
| self.zmq_context.term() | ||
|
|
||
| def start_gpu_profiling(self) -> None: | ||
| """Start GPU profiling.""" | ||
| torch.cuda.profiler.start() | ||
|
|
||
| def stop_gpu_profiling(self) -> None: | ||
| """Stop GPU profiling.""" | ||
| torch.cuda.profiler.stop() | ||
|
|
||
| def report_node_ip_and_gpu_id(self) -> tuple[str, int]: | ||
| """Report the node IP and GPU ID of the current worker.""" | ||
| ip = ray._private.services.get_node_ip_address() | ||
| gpu_id = ray.get_gpu_ids()[0] | ||
| return (ip, gpu_id) | ||
|
|
||
| def get_reference_policy_logprobs( | ||
| self, *, data: BatchedDataDict[Any], micro_batch_size: Optional[int] = None | ||
| ) -> BatchedDataDict[ReferenceLogprobOutputSpec]: | ||
| """Get the logprobs from the reference policy for a batch of data. | ||
| If micro_batch_size is provided, it will be used instead of the configured | ||
| logprob_batch_size. | ||
| Returns: | ||
| a BatchedDataDict with key "reference_logprobs" and shape [batch_size, sequence_length]. | ||
| We use the convention that the logprob of the first token is 0 so that the sequence length is maintained. | ||
| The logprob of input token i is specified at position i in the output logprobs tensor. | ||
| """ | ||
| with self.use_reference_model(): | ||
| reference_logprobs = self.get_logprobs( | ||
| data=data, micro_batch_size=micro_batch_size | ||
| ) | ||
|
|
||
| return_data = BatchedDataDict[ReferenceLogprobOutputSpec]() | ||
| return_data["reference_logprobs"] = reference_logprobs["logprobs"].cpu() | ||
| return return_data | ||
|
ashors1 marked this conversation as resolved.
|
||
|
ashors1 marked this conversation as resolved.
|
File renamed without changes.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.