Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/_automodel_integration_check.yml
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,8 @@ jobs:
echo "Checking if dtensor policy worker files are synchronized..."

# Define the dtensor policy worker file paths
DTENSOR_POLICY_WORKER_FILE="nemo_rl/models/policy/dtensor_policy_worker.py"
DTENSOR_POLICY_WORKER_V2_FILE="nemo_rl/models/policy/dtensor_policy_worker_v2.py"
DTENSOR_POLICY_WORKER_FILE="nemo_rl/models/policy/workers/dtensor_policy_worker.py"
DTENSOR_POLICY_WORKER_V2_FILE="nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py"

# Check if dtensor_policy_worker.py was modified in this PR
if git diff --name-only origin/${{ inputs.base_ref }}..HEAD | grep -q "^${DTENSOR_POLICY_WORKER_FILE}$"; then
Expand Down
10 changes: 5 additions & 5 deletions docs/fp8.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ FP8 generations are recommended to be configured with the following settings:
use_activation_pow2_scale: False
```

"To train with FP8, you need to set the Megatron path and configure it using the following settings:
To train with FP8, you need to set the Megatron path and configure it using the following settings:

```
policy:
Expand All @@ -68,12 +68,12 @@ FP8 generations are recommended to be configured with the following settings:

The TransformerEngine implementation for this recipe requires **cuda version ≥ 12.9**. The latest nemo-rl depends on torch 2.8.0 + cuda 12.9 (since this [commit](https://github.com/NVIDIA-NeMo/RL/commit/3f36d14b53e906b27c01c06e36dbbd2b8eb300cd)). Users should check out the latest code and build the container from `docker/Dockerfile` ([instructions](docker.md)).

If you are using nemo-rl before this [commit](https://github.com/NVIDIA-NeMo/RL/commit/3f36d14b53e906b27c01c06e36dbbd2b8eb300cd), you will see the following error when trying to use fp8 training
If you are using nemo-rl before this [commit](https://github.com/NVIDIA-NeMo/RL/commit/3f36d14b53e906b27c01c06e36dbbd2b8eb300cd), you will see the following error when trying to use fp8 training:

```
File "/opt/ray_venvs/nemo_rl.models.policy.megatron_policy_worker.MegatronPolicyWorker/lib/python3.12/site-packages/transformer_engine/pytorch/fp8.py", line 646, in fp8_autocast
File "/opt/ray_venvs/nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker/lib/python3.12/site-packages/transformer_engine/pytorch/fp8.py", line 646, in fp8_autocast
FP8GlobalStateManager.fp8_autocast_enter(
File "/opt/ray_venvs/nemo_rl.models.policy.megatron_policy_worker.MegatronPolicyWorker/lib/python3.12/site-packages/transformer_engine/pytorch/fp8.py", line 465, in fp8_autocast_enter
File "/opt/ray_venvs/nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker/lib/python3.12/site-packages/transformer_engine/pytorch/fp8.py", line 465, in fp8_autocast_enter
assert fp8_block_available, reason_for_no_fp8_block
^^^^^^^^^^^^^^^^^^^
AssertionError: FP8 block scaled GEMM requires Hopper and CUDA >= 12.9.
Expand All @@ -88,5 +88,5 @@ The above results are from Llama-3.1-8B-Instruct GRPO experiments. You can run t
* For BF16: `examples/configs/grpo_math_8B_megatron.yaml`
* For FP8: `examples/configs/grpo_math_8B_megatron_fp8.yaml`

In the experiment in this figure, enabling FP8 rollout and training gives 15%-25% decrease in step time, and the validation accuracy curves match up to 1000 step.
In the experiment in this figure, enabling FP8 rollout and training gives 15%-25% decrease in step time, and the validation accuracy curves match up to 1000 steps.
Efforts are ongoing to perform longer runs and further optimize performance.
6 changes: 3 additions & 3 deletions nemo_rl/distributed/ray_actor_environment_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@
"nemo_rl.models.generation.vllm.vllm_worker_async.VllmAsyncGenerationWorker": VLLM_EXECUTABLE,
# Temporary workaround for the coupled implementation of DTensorPolicyWorker and vLLM.
# This will be reverted to PY_EXECUTABLES.BASE once https://github.com/NVIDIA-NeMo/RL/issues/501 is resolved.
"nemo_rl.models.policy.dtensor_policy_worker.DTensorPolicyWorker": VLLM_EXECUTABLE,
"nemo_rl.models.policy.dtensor_policy_worker_v2.DTensorPolicyWorkerV2": PY_EXECUTABLES.AUTOMODEL,
"nemo_rl.models.policy.megatron_policy_worker.MegatronPolicyWorker": MCORE_EXECUTABLE,
"nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker": VLLM_EXECUTABLE,
"nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2": PY_EXECUTABLES.AUTOMODEL,
"nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker": MCORE_EXECUTABLE,
"nemo_rl.environments.math_environment.MathEnvironment": PY_EXECUTABLES.SYSTEM,
"nemo_rl.environments.vlm_environment.VLMEnvironment": PY_EXECUTABLES.SYSTEM,
"nemo_rl.environments.code_environment.CodeEnvironment": PY_EXECUTABLES.SYSTEM,
Expand Down
9 changes: 9 additions & 0 deletions nemo_rl/models/policy/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ def train(
data: BatchedDataDict,
loss_fn: LossFunction,
eval_mode: bool = False,
*
gbs: Optional[int] = None,
mbs: Optional[int] = None,
) -> dict[str, Any]:
Expand Down Expand Up @@ -156,6 +157,10 @@ def finish_training(self, *args: Any, **kwargs: Any) -> None:
def save_checkpoint(self, *args: Any, **kwargs: Any) -> None:
pass

@abstractmethod
def load_checkpoint(self, *args: Any, **kwargs: Any) -> None:
pass

@abstractmethod
def shutdown(self) -> bool:
pass
Expand Down Expand Up @@ -191,3 +196,7 @@ def broadcast_weights_for_collective(
self, kv_scales: Optional[dict[str, float]] = None
) -> list[ray.ObjectRef]:
pass

@abstractmethod
def prepare_for_lp_inference(self) -> None:
pass
6 changes: 3 additions & 3 deletions nemo_rl/models/policy/lm_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def __init__(
)
if megatron_enable:
worker_builder_cls = (
"nemo_rl.models.policy.megatron_policy_worker.MegatronPolicyWorker"
"nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker"
)
tp_size = config["megatron_cfg"]["tensor_model_parallel_size"]
pp_size = config["megatron_cfg"]["pipeline_model_parallel_size"]
Expand All @@ -112,10 +112,10 @@ def __init__(
# Check if _v2 is enabled in dtensor_cfg (defaults to False for backward compatibility)
use_v2 = config.get("dtensor_cfg", {}).get("_v2", False)
if use_v2:
worker_builder_cls = "nemo_rl.models.policy.dtensor_policy_worker_v2.DTensorPolicyWorkerV2"
worker_builder_cls = "nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2"
else:
worker_builder_cls = (
"nemo_rl.models.policy.dtensor_policy_worker.DTensorPolicyWorker"
"nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker"
)

tp_size = config["dtensor_cfg"]["tensor_parallel_size"]
Expand Down
135 changes: 135 additions & 0 deletions nemo_rl/models/policy/workers/base_policy_worker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import ray
import torch
import zmq
from typing import Any, Optional

Comment thread
ashors1 marked this conversation as resolved.
from nemo_rl.distributed.batched_data_dict import BatchedDataDict
from nemo_rl.models.policy.interfaces import ReferenceLogprobOutputSpec


class AbstractPolicyWorker:
    """Base class for policy workers with shared functionality.

    Several methods rely on members that this class does not define and that
    concrete workers must provide: ``self.rank``, ``self.model``,
    ``self.get_logprobs`` and ``self.use_reference_model``.
    """

    def init_collective(
        self, ip: str, port: int, world_size: int, *, train_world_size: int
    ) -> None:
        """Initialize the collective communication.

        Creates a stateless process group for this worker's rank and stores the
        resulting NCCL communicator on ``self.model_update_group``.

        Args:
            ip: IP address for the process group
            port: Port for the process group
            world_size: Total world size (train_world_size + inference_world_size)
            train_world_size: Number of training workers (used in inference cluster).
                Not read by this implementation.
        """
        # vLLM is imported at call time rather than at module import time.
        from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
        from vllm.distributed.utils import StatelessProcessGroup

        pg = StatelessProcessGroup.create(
            host=ip, port=port, rank=self.rank, world_size=world_size
        )
        device = torch.cuda.current_device()
        self.model_update_group = PyNcclCommunicator(pg, device=device)

    def is_alive(self) -> bool:
        """Check if the worker is alive."""
        return True

    def reset_peak_memory_stats(self) -> None:
        """Reset peak memory statistics."""
        torch.cuda.reset_peak_memory_stats()

    def get_gpu_info(self) -> dict[str, Any]:
        """Return information about the GPU being used by this worker."""
        from nemo_rl.models.policy.utils import get_gpu_info

        return get_gpu_info(self.model)

    def report_device_id(self) -> str:
        """Report the UUID of the current CUDA device using NVML.

        Returns:
            str: UUID of the device in the format "GPU-xxxxx"
        """
        from nemo_rl.utils.nvml import get_device_uuid

        # Get current device index from torch
        device_idx = torch.cuda.current_device()
        # Get device UUID using NVML
        return get_device_uuid(device_idx)

    def get_zmq_address(self) -> str:
        """Get the ZMQ address for the current device."""
        return f"ipc:///tmp/{self.report_device_id()}.sock"

    def maybe_init_zmq(self) -> None:
        """Initialize the ZMQ socket if it doesn't exist.

        The presence of the ``zmq_socket`` attribute is what marks the socket as
        initialized; when it is missing, a REQ socket is created and bound to
        ``get_zmq_address()``.
        """
        if not hasattr(self, "zmq_socket"):
            self.zmq_context = zmq.Context()
            self.zmq_socket = self.zmq_context.socket(zmq.REQ)
            self.zmq_socket.setsockopt(
                zmq.SNDTIMEO, 120000
            )  # set timeout to 120 seconds
            self.zmq_socket.setsockopt(
                zmq.RCVTIMEO, 120000
            )  # set timeout to 120 seconds
            self.zmq_socket.setsockopt(zmq.LINGER, 0)
            self.zmq_socket.bind(self.get_zmq_address())

    def get_free_memory_bytes(self) -> int:
        """Get the available free memory."""
        from nemo_rl.utils.nvml import get_free_memory_bytes

        device_idx = torch.cuda.current_device()
        return get_free_memory_bytes(device_idx)

    def shutdown(self) -> bool:
        """Shutdown the policy and release its ZMQ resources.

        Safe to call repeatedly: the socket and context handles are dropped once
        closed, so a second call does nothing and a later ``maybe_init_zmq``
        builds a fresh socket instead of reusing a closed one.

        Returns:
            True once cleanup has completed. This matches the abstract
            ``shutdown(self) -> bool`` declared in the policy interfaces.
        """
        # Clean up extension resources like ZMQ sockets
        if hasattr(self, "zmq_socket"):
            self.zmq_socket.close()
            self.zmq_context.term()
            # maybe_init_zmq keys off the attribute's existence, so forget the
            # closed handles rather than leaving stale ones behind.
            del self.zmq_socket
            del self.zmq_context
        return True

    def start_gpu_profiling(self) -> None:
        """Start GPU profiling."""
        torch.cuda.profiler.start()

    def stop_gpu_profiling(self) -> None:
        """Stop GPU profiling."""
        torch.cuda.profiler.stop()

    def report_node_ip_and_gpu_id(self) -> tuple[str, int]:
        """Report the node IP and GPU ID of the current worker.

        Returns:
            A ``(ip, gpu_id)`` pair, where ``gpu_id`` is the first entry of
            ``ray.get_gpu_ids()``.
        """
        ip = ray._private.services.get_node_ip_address()
        gpu_id = ray.get_gpu_ids()[0]
        return (ip, gpu_id)

    def get_reference_policy_logprobs(
        self, *, data: BatchedDataDict[Any], micro_batch_size: Optional[int] = None
    ) -> BatchedDataDict[ReferenceLogprobOutputSpec]:
        """Get the logprobs from the reference policy for a batch of data.

        If micro_batch_size is provided, it will be used instead of the configured
        logprob_batch_size.

        Args:
            data: Batch to score; forwarded unchanged to ``self.get_logprobs``.
            micro_batch_size: Optional override forwarded to ``self.get_logprobs``.

        Returns:
            a BatchedDataDict with key "reference_logprobs" and shape [batch_size, sequence_length].
            We use the convention that the logprob of the first token is 0 so that the sequence length is maintained.
            The logprob of input token i is specified at position i in the output logprobs tensor.
        """
        # Run the regular logprob path while the reference model is active.
        with self.use_reference_model():
            reference_logprobs = self.get_logprobs(
                data=data, micro_batch_size=micro_batch_size
            )

        return_data = BatchedDataDict[ReferenceLogprobOutputSpec]()
        # The result is moved to CPU before being returned.
        return_data["reference_logprobs"] = reference_logprobs["logprobs"].cpu()
        return return_data
Comment thread
ashors1 marked this conversation as resolved.
Comment thread
ashors1 marked this conversation as resolved.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@
LogprobOutputSpec,
ReferenceLogprobOutputSpec,
ScoreOutputSpec,
ColocatablePolicyInterface,
)
from nemo_rl.models.policy.utils import (
configure_dynamo_cache,
Expand All @@ -85,6 +86,7 @@
import_class_from_path,
resolve_model_class,
)
from nemo_rl.models.policy.workers.base_policy_worker import AbstractPolicyWorker
from nemo_rl.utils.automodel_checkpoint import (
load_checkpoint,
save_checkpoint,
Expand All @@ -97,7 +99,7 @@
@ray.remote(
runtime_env=get_runtime_env_for_policy_worker("dtensor_policy_worker_v2")
) # pragma: no cover
class DTensorPolicyWorkerV2:
class DTensorPolicyWorkerV2(ColocatablePolicyInterface, AbstractPolicyWorker):
def __repr__(self) -> str:
"""Customizes the actor's prefix in the Ray logs.

Expand Down Expand Up @@ -473,9 +475,6 @@ def init_collective(
device = torch.cuda.current_device()
self.model_update_group = PyNcclCommunicator(pg, device=device)

def is_alive(self) -> bool:
return True

def check_model_allow_flash_attn_args(self, model_config) -> bool:
# Some models don't support flash_attn_kwargs
# Check nemotron nas.
Expand All @@ -487,13 +486,6 @@ def check_model_allow_flash_attn_args(self, model_config) -> bool:

return True

def reset_peak_memory_stats(self) -> None:
torch.cuda.reset_peak_memory_stats()

def get_gpu_info(self) -> dict[str, Any]:
"""Return information about the GPU being used by this worker."""
return get_gpu_info(self.model)

@wrap_with_nvtx_name("dtensor_policy_worker_v2/train")
def train(
self,
Expand Down Expand Up @@ -1615,24 +1607,6 @@ def use_reference_model(self) -> Generator[None, None, None]:
val = to_local_if_dtensor(v)
val.copy_(curr_state_dict[k])

@wrap_with_nvtx_name("dtensor_policy_worker_v2/get_reference_policy_logprobs")
def get_reference_policy_logprobs(
self, data: BatchedDataDict[Any], micro_batch_size: Optional[int] = None
) -> BatchedDataDict[ReferenceLogprobOutputSpec]:
"""Get the logprobs from the reference policy for a batch of data.

Returns:
a BatchedDataDict with key "reference_logprobs" and shape [batch_size, sequence_length].
We use the convention that the logprob of the first token is 0 so that the sequence length is maintained.
The logprob of input token i is specified at position i in the output logprobs tensor.
"""
with self.use_reference_model():
reference_logprobs = self.get_logprobs(data, micro_batch_size)

return_data = BatchedDataDict[ReferenceLogprobOutputSpec]()
return_data["reference_logprobs"] = reference_logprobs["logprobs"].cpu()
return return_data

def _add_noise_to_weights(self) -> None:
"""Add small Gaussian noise to the weights of the model. Note that this is used for testing purposes only."""
noise_std = 0.01 # Standard deviation for the noise
Expand All @@ -1653,37 +1627,6 @@ def return_model_config(self) -> dict[str, Any]:
"""
return self.model.config

def report_device_id(self) -> str:
"""Report the UUID of the current CUDA device using NVML.

Returns:
str: UUID of the device in the format "GPU-xxxxx"
"""
from nemo_rl.utils.nvml import get_device_uuid

# Get current device index from torch
device_idx = torch.cuda.current_device()
# Get device UUID using NVML
return get_device_uuid(device_idx)

def get_zmq_address(self):
"""Get the ZMQ address for the current device."""
return f"ipc:///tmp/{self.report_device_id()}.sock"

def maybe_init_zmq(self):
"""Initialize the ZMQ socket if it doesn't exist."""
if not hasattr(self, "zmq_socket"):
self.zmq_context = zmq.Context()
self.zmq_socket = self.zmq_context.socket(zmq.REQ)
self.zmq_socket.setsockopt(
zmq.SNDTIMEO, 120000
) # set timeout to 120 seconds
self.zmq_socket.setsockopt(
zmq.RCVTIMEO, 120000
) # set timeout to 120 seconds
self.zmq_socket.setsockopt(zmq.LINGER, 0)
self.zmq_socket.bind(self.get_zmq_address())

@torch.no_grad()
def prepare_refit_info(self) -> Optional[dict[str, Any]]:
"""Prepare state dict metadata for weight refitting and IPC streaming."""
Expand All @@ -1694,13 +1637,6 @@ def prepare_refit_info(self) -> Optional[dict[str, Any]]:

return state_dict_info

def get_free_memory_bytes(self) -> int:
"""Get the available free memory."""
from nemo_rl.utils.nvml import get_free_memory_bytes

device_idx = torch.cuda.current_device()
return get_free_memory_bytes(device_idx)

@torch.no_grad()
def calibrate_qkv_fp8_scales(
self,
Expand Down Expand Up @@ -1949,25 +1885,4 @@ def load_checkpoint(
optimizer=self.optimizer if optimizer_path else None,
scheduler=self.scheduler if optimizer_path else None,
optimizer_path=optimizer_path,
)

def shutdown(self) -> None:
"""Shutdown the policy."""
# Clean up extension resources like ZMQ sockets
if hasattr(self, "zmq_socket"):
self.zmq_socket.close()
self.zmq_context.term()

def start_gpu_profiling(self) -> None:
"""Start GPU profiling."""
torch.cuda.profiler.start()

def stop_gpu_profiling(self) -> None:
"""Stop GPU profiling."""
torch.cuda.profiler.stop()

def report_node_ip_and_gpu_id(self) -> list[tuple[str, int]]:
"""Report the node IP and GPU ID of the current worker."""
ip = ray._private.services.get_node_ip_address()
gpu_id = ray.get_gpu_ids()[0]
return (ip, gpu_id)
)
Loading
Loading