Merged
2 changes: 1 addition & 1 deletion README.md
@@ -47,7 +47,7 @@ uv pip install -e '.[dev,test]'
# Use uv run to launch any runs.
# Note that it is recommended to not activate the venv and instead use `uv run` since
# it ensures consistent environment usage across different shells and sessions.
uv run python examples/run_grpo.py
uv run python examples/run_grpo_math.py
```

## Cluster Start
Binary file added docs/assets/actor-wg-worker-vc.png
33 changes: 25 additions & 8 deletions docs/design_docs/design_and_philosophy.md
@@ -9,8 +9,6 @@ Online RL requires coordinating a lot of different pieces of software/models

We refer to each of these pieces of software as an **RL Actor**.

[TODO @sahilj Diagram]

Fundamentally, we need to be able to do 4 things between these RL Actors:
- Resource them (provide GPUs/CPUs)
- Isolate them
@@ -32,6 +30,8 @@ We create composable and hackable abstractions for each layer of the tasks above

By creating a common interface for these 4 tasks, **RL algorithm code looks the same from 1 GPU to 1000 GPUs and does not care about the implementation of each RL Actor (Megatron, HF, Grad student with pen and paper)**

![actor-wg-worker-vc](../assets/actor-wg-worker-vc.png)

### {py:class}`RayVirtualCluster <nemo_reinforcer.distributed.virtual_cluster.RayVirtualCluster>`
VirtualCluster provides a thin abstraction on top of Ray Placement Groups that lets you section off part of your compute resources for WorkerGroups to run on as though they had their own cluster. A VirtualCluster can run a single WorkerGroup, or support *colocation*, where multiple WorkerGroups share the same resources (e.g., running policy training (HF) and generation (vLLM) on the same GPUs in turn).

@@ -84,12 +84,29 @@ class RayWorkerGroup:
- Support for tied worker groups where multiple workers process the same data
"""
```
[TODO @sahilj Diagram]

`RayWorkerGroup` provides functions like `run_all_workers_single_data` and `run_all_workers_multiple_data` to control and communicate with individual worker processes.
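As a rough mental model of these two dispatch patterns, consider the following simplified toy (hypothetical `ToyWorkerGroup` and `EchoWorker` classes for illustration only — the real `RayWorkerGroup` fans work out to Ray remote actors and returns futures):

```python
# Simplified, hypothetical sketch of the two dispatch patterns.
# The real RayWorkerGroup sends work to Ray remote actors.
class ToyWorkerGroup:
    def __init__(self, workers):
        self.workers = workers  # e.g. one object per GPU process

    def run_all_workers_single_data(self, method_name, data):
        # Every worker receives the same data (e.g. a broadcast config update).
        return [getattr(w, method_name)(data) for w in self.workers]

    def run_all_workers_multiple_data(self, method_name, data_shards):
        # Each worker receives its own shard (e.g. data-parallel training).
        assert len(data_shards) == len(self.workers)
        return [getattr(w, method_name)(shard)
                for w, shard in zip(self.workers, data_shards)]


class EchoWorker:
    def process(self, x):
        return x * 2


group = ToyWorkerGroup([EchoWorker(), EchoWorker()])
broadcast = group.run_all_workers_single_data("process", 3)       # [6, 6]
sharded = group.run_all_workers_multiple_data("process", [1, 2])  # [2, 4]
```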


### Single-Controller & Execution Diagram

## Walking through an implementation of GRPO


We control the RL Actors from a single-process head controller. With the abstractions above, the main loop of GRPO can be written as though we were working on a single GPU:
```python
# data processing/transformations between each step omitted
def grpo_train(
    policy: PolicyInterface,
    policy_generation: GenerationInterface,
    environment: EnvironmentInterface,
    dataloader: Iterable[BatchedDataDict[DatumSpec]],
):
    loss_fn = GRPOLossFn()
    for batch in dataloader:
        batch.repeat_interleave(num_generations_per_prompt)  # repeat for GRPO
        generations = policy_generation.generate(batch)
        rewards = environment.step(generations)

        logprobs = policy.get_logprobs(generations)
        reference_logprobs = policy.get_reference_logprobs(generations)

        training_data = calculate_grpo_training_data(generations, logprobs, reference_logprobs, rewards)
        policy.train(training_data, loss_fn)
```
For a real implementation of GRPO (with validation, checkpointing, memory movement, and the omitted data processing steps), see [grpo_train](../../nemo_reinforcer/algorithms/grpo.py).
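The `normalize_rewards` and `use_leave_one_out_baseline` options in the GRPO configs control how per-group advantages are computed from rewards. A minimal, illustrative sketch of that computation (the `group_advantages` helper is ours, not the codebase's; the real logic lives in the advantage/loss code):

```python
import statistics

def group_advantages(rewards, leave_one_out=True, normalize=True):
    """GRPO-style advantages for one prompt's group of generations.

    rewards: list of scalar rewards, one per generation of the same prompt.
    """
    n = len(rewards)
    total = sum(rewards)
    if leave_one_out:
        # Baseline for sample i is the mean reward of the *other* samples.
        assert n > 1, "leave-one-out needs at least 2 generations per prompt"
        baselines = [(total - r) / (n - 1) for r in rewards]
    else:
        baselines = [total / n] * n
    adv = [r - b for r, b in zip(rewards, baselines)]
    if normalize:
        std = statistics.pstdev(adv)
        if std > 0:
            adv = [a / std for a in adv]
    return adv

advs = group_advantages([1.0, 0.0, 0.0, 1.0])  # roughly [1, -1, -1, 1]
```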
94 changes: 92 additions & 2 deletions docs/guides/grpo.md
@@ -1,3 +1,93 @@
# GRPO
# An in-depth walkthrough of GRPO in Reinforcer

placeholder TBD
## Quickstart: Launch a GRPO Run

If you want to get running quickly, the script [examples/run_grpo_math.py](../../examples/run_grpo_math.py) has an example implementation of using GRPO to train a model on math problems. The script can be launched either locally or via Slurm; for details on how to set up Ray and launch a job with Slurm, refer to the [cluster documentation](../cluster.md).

We recommend launching the job using `uv`:
```bash
uv run examples/run_grpo_math.py --config <PATH TO YAML CONFIG> {overrides}
```
If not specified, `config` will default to [examples/configs/grpo_math_1B.yaml](../../examples/configs/grpo_math_1B.yaml).

## Now, for the details:

In this guide, we'll walk through how we handle:
* Data
* Model training
* Fast generation
* Overall Resource Flow

### Data
We support training with multiple RL "Environments" at the same time.

An [Environment](../../nemo_reinforcer/environments/interfaces.py) is an object that accepts a state/action history and returns an updated state and rewards for the step. Environments run as Ray remote actors; see [MathEnvironment](../../nemo_reinforcer/environments/math_environment.py) for an example.
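As an illustration of the idea (the class and method shapes below are simplified assumptions, not the real interface in `nemo_reinforcer/environments/interfaces.py`), a toy math environment might look like:

```python
# Simplified, hypothetical environment sketch. The real interface runs as a
# Ray remote actor (e.g. the class is decorated with @ray.remote).
class ToyMathEnvironment:
    def step(self, message_logs, env_infos):
        """Score one batch: message_logs are per-sample chat histories,
        env_infos carry ground truth (e.g. the expected answer)."""
        rewards = []
        for log, info in zip(message_logs, env_infos):
            answer = log[-1]["content"].strip()  # last assistant message
            rewards.append(1.0 if answer == info["answer"] else 0.0)
        # Single-turn task: every episode terminates after one step.
        return rewards, [True] * len(rewards)


env = ToyMathEnvironment()
rewards, done = env.step(
    [[{"role": "assistant", "content": "42"}],
     [{"role": "assistant", "content": "41"}]],
    [{"answer": "42"}, {"answer": "42"}],
)
```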

To support this, we need to know:
* What environments you have
* Which data should go to which environments
* How to prepare the data from your dataset into a form we can use

#### Common Data Format
We define a [DatumSpec](../../nemo_reinforcer/data/interfaces.py) that holds all relevant information for each training example:
```python
class DatumSpec(TypedDict):
message_log: LLMMessageLogType
length: int # total (concatenated) length of the message tensors
extra_env_info: Dict[str, Any] # anything your environment requires goes here, for example the 'answer' of a math problem
loss_multiplier: float # multiplier for the loss for this datum. 0 to mask out (say the sample is invalid)
idx: int
task_name: Optional[str] = "default"
__extra__: Any # This allows additional fields of any type
```
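For instance, a single math datum might be assembled like this (illustrative values only; real `message_log` entries also carry tokenized tensors, omitted here for brevity):

```python
# Illustrative DatumSpec for one math problem (not taken from the codebase).
datum = {
    "message_log": [
        {"role": "user", "content": "What is 6 * 7?"},
    ],
    "length": 9,                         # total token length of the messages
    "extra_env_info": {"answer": "42"},  # consumed by the math environment
    "loss_multiplier": 1.0,              # 0.0 would mask this sample out
    "idx": 0,
    "task_name": "math",
}
```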

#### Data Processors
We call each distinct environment your model optimizes against a "task"; you might define a "math" task or a "code" task, for example.
For each task, you should provide a data processor that reads from your dataset and returns a [DatumSpec](../../nemo_reinforcer/data/interfaces.py):

```python
def my_data_processor(
datum_dict: Dict[str, Any], # loaded directly from your dataset (i.e. single line of jsonl data)
task_data_spec: TaskDataSpec,
tokenizer,
max_seq_length: int,
idx: int,
) -> DatumSpec:
```
We have an example of this as `math_data_processor` in [run_grpo_math.py](../../examples/run_grpo_math.py).
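A minimal processor matching the signature above might look like the following sketch (the raw field names `problem` and `expected_answer` are hypothetical; the real example is the `math_data_processor` referenced above):

```python
from typing import Any, Dict

def toy_math_data_processor(
    datum_dict: Dict[str, Any],
    task_data_spec,          # TaskDataSpec (prompt/system-prompt templates)
    tokenizer,
    max_seq_length: int,
    idx: int,
) -> Dict[str, Any]:
    # Hypothetical raw jsonl fields: {"problem": ..., "expected_answer": ...}
    prompt = datum_dict["problem"]
    token_ids = tokenizer.encode(prompt)
    truncated = len(token_ids) > max_seq_length
    return {
        "message_log": [{"role": "user", "content": prompt,
                         "token_ids": token_ids[:max_seq_length]}],
        "length": min(len(token_ids), max_seq_length),
        "extra_env_info": {"answer": datum_dict["expected_answer"]},
        # Mask out samples that did not fit rather than training on them.
        "loss_multiplier": 0.0 if truncated else 1.0,
        "idx": idx,
        "task_name": "math",
    }
```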

#### Putting it all together:
GRPO expects datasets to have the following form:
```json
{"task_name": "math", <actual data>}
```
Then, you can set data up as such:
```python
base_dataset = load_dataset("json", data_files=data_config["dataset_name"])["train"]
tokenizer = AutoTokenizer.from_pretrained(policy_config["model_name"])

task_data_processors = defaultdict(lambda: (math_task_spec, math_data_processor))
task_data_processors["math"] = (math_task_spec, math_data_processor)

math_env = MathEnvironment.remote(env_configs["math"]) # ray remote actor

dataset = AllTaskProcessedDataset(
base_dataset,
tokenizer,
math_task_spec,
task_data_processors,
max_seq_length=data_config["max_input_seq_length"],
)
```
Notice that you provide a mapping of tasks to their processors so the dataset knows what to use when processing samples.


### Policy Model
We define a `PolicyInterface` that contains everything you need to train a policy model.

This Policy object holds a [RayWorkerGroup](../../nemo_reinforcer/distributed/worker_groups.py) of SPMD (one process per GPU) workers running HF/MCore, all coordinated so that, to you, it behaves like a single GPU.

### Fast Generation
We currently support vLLM through the [VllmGeneration](../../nemo_reinforcer/models/generation/vllm.py) class.

The function [grpo_train](../../nemo_reinforcer/algorithms/grpo.py) contains the core GRPO training loop.
2 changes: 1 addition & 1 deletion docs/guides/sft.md
@@ -6,7 +6,7 @@ The script [examples/run_sft.py](../../examples/run_sft.py) can be used to launc

Be sure to launch the job using `uv`. The command to launch an SFT job is as follows:
```bash
uv run examples/run_sft.py --config <PATH TO YAML CONFIG>
uv run examples/run_sft.py --config <PATH TO YAML CONFIG> <OVERRIDES>
```
If not specified, `config` will default to [examples/configs/sft.yaml](../../examples/configs/sft.yaml).

@@ -1,10 +1,8 @@
# GRPO Algorithm Configuration
defaults: "base.yaml"

grpo:
num_prompts_per_step: 8
num_generations_per_prompt: 8
num_steps: 100
max_num_steps: 100
normalize_rewards: true
use_leave_one_out_baseline: true
val_period: 10
@@ -29,16 +27,28 @@ policy:
train_global_batch_size: 32
train_micro_batch_size: 4
generation_batch_size: 32
learning_rate: 5.0e-6
logprob_batch_size: 4
max_total_sequence_length: 1024
max_total_sequence_length: 512

scheduler:
- name: "torch.optim.lr_scheduler.LinearLR"
kwargs:
start_factor: 0.1
end_factor: 1.0
total_iters: 50
- name: "torch.optim.lr_scheduler.ConstantLR"
kwargs:
factor: 1.0
total_iters: 10000000000
- milestones: [50]

generation:
backend: "vllm" # "vllm" or "hf"(to use the hf training framework's generation)
max_new_tokens: ${policy.max_total_sequence_length} # upper bound, real truncation occurs at vllm.max_model_len below
backend: "vllm"
max_new_tokens: ${policy.max_total_sequence_length}
temperature: 1.0
# Keep at 1.0: vLLM logprobs are computed after sampling in the V0 runtime and before sampling in V1.
top_p: 1.0
top_k: null # disable
top_k: null
vllm_cfg:
tensor_parallel_size: 1
gpu_memory_utilization: 0.7
@@ -54,3 +64,17 @@ data:
env:
math:
num_workers: 8

logger:
log_dir: "logs" # Base directory for all logs
num_val_samples_to_print: 0 # Number of validation samples to pretty print on terminal
wandb_enabled: false
tensorboard_enabled: false
wandb:
project: "grpo-dev"
name: "grpo-dev-logger"
tensorboard: {}

cluster:
gpus_per_node: 1
num_nodes: 1
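The scheduler list in the config above chains a linear warmup with a constant schedule, switching at the `milestones` step (in the style of `torch.optim.lr_scheduler.SequentialLR`). As a pure-Python sketch of the resulting multiplier on the base learning rate (`lr_factor` is an illustrative helper, not part of the codebase):

```python
def lr_factor(step, start_factor=0.1, end_factor=1.0,
              warmup_iters=50, milestone=50):
    """Multiplier applied to the base learning rate at a given step.

    Mirrors the config sketchily: LinearLR until `milestone`, then
    ConstantLR with factor 1.0 forever after.
    """
    if step < milestone:
        frac = min(step, warmup_iters) / warmup_iters
        return start_factor + (end_factor - start_factor) * frac
    return 1.0

# With learning_rate 5.0e-6, the effective LR starts at 5.0e-7 at step 0
# and reaches the full 5.0e-6 from step 50 onward.
```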
33 changes: 10 additions & 23 deletions examples/configs/base.yaml → examples/configs/grpo_math_8B.yaml
@@ -1,12 +1,14 @@
# Base configuration with common settings
# GRPO Algorithm Configuration
defaults: "grpo_math_1B.yaml"

policy:
model_name: "meta-llama/Llama-3.2-1B-Instruct"
model_name: "meta-llama/Llama-3.1-8B-Instruct"
train_global_batch_size: 32
train_micro_batch_size: 4
train_micro_batch_size: 1
generation_batch_size: 32
learning_rate: 5.0e-6
logprob_batch_size: 4
max_total_sequence_length: 8192
logprob_batch_size: 2
max_total_sequence_length: 4096

scheduler:
- name: "torch.optim.lr_scheduler.LinearLR"
@@ -28,24 +30,9 @@ policy:
top_k: null
vllm_cfg:
tensor_parallel_size: 1
gpu_memory_utilization: 0.7
gpu_memory_utilization: 0.6
max_model_len: ${policy.max_total_sequence_length}

data:
max_input_seq_length: ${policy.max_total_sequence_length}
prompt_file: "examples/prompts/cot.txt"
system_prompt_file: null

logger:
log_dir: "logs" # Base directory for all logs
num_val_samples_to_print: 0 # Number of validation samples to pretty print on terminal
wandb_enabled: false
tensorboard_enabled: false
wandb:
project: "grpo-dev"
name: "grpo-dev-logger"
tensorboard: {}


cluster:
gpus_per_node: 2
gpus_per_node: 8
num_nodes: 1
2 changes: 1 addition & 1 deletion examples/configs/sft.yaml
@@ -1,6 +1,6 @@
# SFT Algorithm Configuration
sft:
num_steps: 20
max_num_steps: 20
val_period: 10
val_batches: 8
val_global_batch_size: 32
2 changes: 1 addition & 1 deletion examples/run_grpo_math.py
@@ -170,7 +170,7 @@ def main():
args, overrides = parse_args()

if not args.config:
args.config = os.path.join(os.path.dirname(__file__), "configs", "grpo.yaml")
args.config = os.path.join(os.path.dirname(__file__), "configs", "grpo_math_1B.yaml")

config = load_config(args.config)
print(f"Loaded configuration from: {args.config}")
6 changes: 3 additions & 3 deletions nemo_reinforcer/algorithms/grpo.py
@@ -69,7 +69,7 @@
class GRPOConfig(TypedDict):
num_prompts_per_step: int
num_generations_per_prompt: int
num_steps: int
max_num_steps: int
normalize_rewards: bool
use_leave_one_out_baseline: bool
val_period: int
@@ -445,7 +445,7 @@ def grpo_train(

# Run grpo training (single-turn)
for batch in dataloader:
print(f"\n{'=' * 25} Step {step + 1}/{len(dataloader)} {'=' * 25}")
print(f"\n{'=' * 25} Step {step + 1}/{min(len(dataloader), master_config['grpo']['max_num_steps'])} {'=' * 25}")

with timer.time("total_step_time"):
# Prepare batch
@@ -654,7 +654,7 @@ def grpo_train(

timer.reset()
step += 1
if step >= master_config["grpo"]["num_steps"]:
if step >= master_config["grpo"]["max_num_steps"]:
break


4 changes: 2 additions & 2 deletions nemo_reinforcer/algorithms/sft.py
@@ -51,7 +51,7 @@ def _default_sft_save_state() -> SFTSaveState:


class SFTConfig(TypedDict):
num_steps: int
max_num_steps: int
val_period: int
val_batches: int
val_global_batch_size: int
@@ -437,5 +437,5 @@ def sft_train(
timer.reset()
step += 1

if step >= master_config["sft"]["num_steps"]:
if step >= master_config["sft"]["max_num_steps"]:
break
2 changes: 1 addition & 1 deletion tests/functional/grpo.sh
@@ -16,7 +16,7 @@ mkdir -p $LOG_DIR
cd $PROJECT_ROOT
uv run $PROJECT_ROOT/examples/run_grpo_math.py \
cluster.gpus_per_node=2 \
grpo.num_steps=10 \
grpo.max_num_steps=10 \
logger.tensorboard_enabled=true \
logger.log_dir=$LOG_DIR \
logger.wandb_enabled=false \
2 changes: 1 addition & 1 deletion tests/functional/sft.sh
@@ -16,7 +16,7 @@ mkdir -p $LOG_DIR
cd $PROJECT_ROOT
uv run $PROJECT_ROOT/examples/run_sft.py \
cluster.gpus_per_node=2 \
sft.num_steps=10 \
sft.max_num_steps=10 \
logger.tensorboard_enabled=true \
logger.log_dir=$LOG_DIR \
logger.wandb_enabled=false \