Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions examples/configs/grpo_math_1B.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,15 @@ policy:
top_k: null
stop_token_ids: null
stop_strings: null
mcore_generation_config:
buffer_size_gb: 20 # Total GPU memory (in GB) allocated for KV cache buffers
buffer_guaranteed_fraction: 0.1 # Fraction of buffer reserved for guaranteed active requests
num_cuda_graphs: 16 # Number of CUDA graphs to pre-compile for different batch sizes
block_size_tokens: 256 # Size of each KV cache block in tokens (affects memory granularity)
use_cuda_graphs_for_non_decode_steps: true # Enable CUDA graphs for prefill/context processing
enable_chunked_prefill: true # Split long prefills into chunks for better memory management
unified_memory_level: 0 # Unified memory usage level (0=disabled, higher values enable more aggressive paging)
max_tokens: 16384 # Maximum number of tokens to use in a single step. Analogous to vllm's max_num_batched_tokens
vllm_cfg:
async_engine: false
precision: ${policy.precision}
Expand Down
2 changes: 1 addition & 1 deletion examples/configs/grpo_math_1B_megatron.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ policy:
use_cuda_graphs_for_non_decode_steps: true # Enable CUDA graphs for prefill/context processing
enable_chunked_prefill: true # Split long prefills into chunks for better memory management
unified_memory_level: 0 # Unified memory usage level (0=disabled, higher values enable more aggressive paging)
max_tokens: 16384 # Maximum number of tokens to use in a single step
max_tokens: 16384 # Maximum number of tokens to use in a single step. Analogous to vllm's max_num_batched_tokens

vllm_cfg:
tensor_parallel_size: 1
Expand Down
53 changes: 37 additions & 16 deletions nemo_rl/models/policy/workers/megatron_policy_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from collections import defaultdict
from contextlib import AbstractContextManager, contextmanager, nullcontext
from functools import partial
from typing import Any, Iterator, Optional, TypeVar, cast
from typing import Any, Iterator, Optional, TypedDict, TypeVar, cast

import ray
import torch
Expand Down Expand Up @@ -145,6 +145,27 @@
TokenizerType = TypeVar("TokenizerType", bound=PreTrainedTokenizerBase)


class MegatronGenerationConfig(TypedDict):
    """Settings for Megatron-Core inference (KV-cache sizing and CUDA-graph behavior).

    Mirrors the ``mcore_generation_config`` mapping in the policy YAML config;
    the worker reads it via ``self.cfg["generation"]["mcore_generation_config"]``
    and indexes every key directly, so all keys are required (no defaults are
    applied at the access site).
    """

    # Total GPU memory (in GB) allocated for KV cache buffers
    buffer_size_gb: int
    # Fraction of buffer reserved for guaranteed active requests
    buffer_guaranteed_fraction: float
    # Number of CUDA graphs to pre-compile for different batch sizes
    num_cuda_graphs: int
    # Size of each KV cache block in tokens (affects memory granularity)
    block_size_tokens: int
    # Enable CUDA graphs for prefill/context processing
    use_cuda_graphs_for_non_decode_steps: bool
    # Split long prefills into chunks for better memory management
    enable_chunked_prefill: bool
    # Unified memory usage level (0=disabled, higher values enable more aggressive paging)
    unified_memory_level: int
    # Maximum number of tokens to use in a single step. Analogous to vllm's max_num_batched_tokens.
    # Can cause OOM if set too high so should be tuned with buffer_size_gb if OOMing. If set too
    # low, then will only do 512 tokens at a time, which can be slow.
    max_tokens: int


def broadcast_object_across_pp_ranks(obj):
"""Broadcast an object across pipeline parallel ranks.

Expand Down Expand Up @@ -1820,22 +1841,22 @@ def generate(
)
from megatron.core.inference.sampling_params import SamplingParams

mcore_generation_config = self.cfg["generation"]["mcore_generation_config"]
buffer_size_gb = mcore_generation_config.get("buffer_size_gb", 20)

num_cuda_graphs = mcore_generation_config.get("num_cuda_graphs", 16)
block_size_tokens = mcore_generation_config.get("block_size_tokens", 256)
use_cuda_graphs_for_non_decode_steps = mcore_generation_config.get(
"use_cuda_graphs_for_non_decode_steps", True
)
enable_chunked_prefill = mcore_generation_config.get(
"enable_chunked_prefill", True
mcore_generation_config = cast(
MegatronGenerationConfig, self.cfg["generation"]["mcore_generation_config"]
)
unified_memory_level = mcore_generation_config.get("unified_memory_level", 0)
buffer_guaranteed_fraction = mcore_generation_config.get(
"buffer_guaranteed_fraction", 0.1
)
max_tokens = mcore_generation_config.get("max_tokens", 16384)
buffer_size_gb = mcore_generation_config["buffer_size_gb"]

num_cuda_graphs = mcore_generation_config["num_cuda_graphs"]
block_size_tokens = mcore_generation_config["block_size_tokens"]
use_cuda_graphs_for_non_decode_steps = mcore_generation_config[
"use_cuda_graphs_for_non_decode_steps"
]
enable_chunked_prefill = mcore_generation_config["enable_chunked_prefill"]
unified_memory_level = mcore_generation_config["unified_memory_level"]
buffer_guaranteed_fraction = mcore_generation_config[
"buffer_guaranteed_fraction"
]
max_tokens = mcore_generation_config["max_tokens"]

model_config = self.model.config
model_config.cuda_graph_impl = "local"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,12 @@ uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS

# Only run metrics if the target step is reached
if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | map(tonumber) | max' $JSON_METRICS) -ge $MAX_STEPS ]]; then
# total_step_time observed at ~16 s, so the 17.5 threshold leaves headroom
uv run tests/check_metrics.py $JSON_METRICS \
'median(data["train/token_mult_prob_error"]) < 1.1' \
'data["train/token_mult_prob_error"]["500"] < 1.1' \
'data["train/reward"]["500"] > 0.1' \
'mean(data["timing/train/total_step_time"], -6, -1) < 10.5'
'mean(data["timing/train/total_step_time"], -6, -1) < 17.5'

# Clean up checkpoint directory after successful run to save space.
rm -rf "$CKPT_DIR"
Expand Down
4 changes: 4 additions & 0 deletions tests/unit/models/policy/test_megatron_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,10 @@ def create_megatron_test_config(
"buffer_size_gb": 20,
"buffer_guaranteed_fraction": 0.1,
"num_cuda_graphs": 16,
"block_size_tokens": 256,
"use_cuda_graphs_for_non_decode_steps": True,
"enable_chunked_prefill": True,
"unified_memory_level": 0,
"max_tokens": 16384,
},
"colocated": {
Expand Down
Loading