diff --git a/README.md b/README.md
index 4668a90ede..edff7aba42 100644
--- a/README.md
+++ b/README.md
@@ -47,7 +47,7 @@ uv pip install -e '.[dev,test]'
 
 # Use uv run to launch any runs.
 # Note that it is recommended to not activate the venv and instead use `uv run` since
 # it ensures consistent environment usage across different shells and sessions.
-uv run python examples/run_grpo.py
+uv run python examples/run_grpo_math.py
 ```
 
 ## Cluster Start
diff --git a/docs/assets/actor-wg-worker-vc.png b/docs/assets/actor-wg-worker-vc.png
new file mode 100644
index 0000000000..fe360c9939
Binary files /dev/null and b/docs/assets/actor-wg-worker-vc.png differ
diff --git a/docs/design_docs/design_and_philosophy.md b/docs/design_docs/design_and_philosophy.md
index ba1c2e28c9..e9fead87e8 100644
--- a/docs/design_docs/design_and_philosophy.md
+++ b/docs/design_docs/design_and_philosophy.md
@@ -9,8 +9,6 @@ Online RL requires coordinating a lot of different pieces of software/models
 
 We refer to each of these pieces of software as an **RL Actor**.
 
-[TODO @sahilj Diagram]
-
 Fundamentally, we need to be able to do 4 things between these RL Actors:
 - Resource them (provide GPUs/CPUs)
 - Isolate them
@@ -32,6 +30,8 @@ We create composable and hackable abstractions for each layer of the tasks above
 
 By creating a common interface for these 4 tasks, **RL algorithm code looks the same from 1 GPU to 1000 GPUs and does not care about the implementation of each RL Actor (Megatron, HF, Grad student with pen and paper)**
 
+![actor-wg-worker-vc](../assets/actor-wg-worker-vc.png)
+
 ### {py:class}`RayVirtualCluster `
 VirtualCluster provides a basic abstraction on top of Ray Placement Groups that allow you to section off a part of your compute resources for WorkerGroups to run on as though they had their own cluster.
 They support running just one WorkerGroup on each VirtualCluster, or *colocation*, where multiple WorkerGroups share resources (i.e. running policy training (HF) and generation (vLLM) on the same GPUs in turn).
@@ -84,12 +84,29 @@ class RayWorkerGroup:
     - Support for tied worker groups where multiple workers process the same data
     """
 ```
-[TODO @sahilj Diagram]
-
+`RayWorkerGroup` provides functions like `run_all_workers_single_data` and `run_all_workers_multiple_data` to control and communicate with individual worker processes.
 
 ### Single-Controller & Execution Diagram
-
-## Walking through an implementation of GRPO
-
-
+We control the RL Actors using a single-process head controller. Using the abstractions above, we can write the main loop of GRPO as though we were working on 1 GPU:
+```python
+# data processing/transformations between each step omitted
+def grpo_train(
+    policy: PolicyInterface,
+    policy_generation: GenerationInterface,
+    environment: EnvironmentInterface,
+    dataloader: Iterable[BatchedDataDict[DatumSpec]],
+):
+    loss_fn = GRPOLossFn()
+    for batch in dataloader:
+        batch.repeat_interleave(num_generations_per_prompt)  # repeat for GRPO
+        generations = policy_generation.generate(batch)
+        rewards = environment.step(generations)
+
+        logprobs = policy.get_logprobs(generations)
+        reference_logprobs = policy.get_reference_logprobs(generations)
+
+        training_data = calculate_grpo_training_data(generations, logprobs, reference_logprobs, rewards)
+        policy.train(training_data, loss_fn)
+```
+For a real implementation of GRPO (with validation, checkpointing, memory movement, and the omitted data processing steps), see [grpo_train](../../nemo_reinforcer/algorithms/grpo.py)
diff --git a/docs/guides/grpo.md b/docs/guides/grpo.md
index 0925548d0d..8d93fb64f7 100644
--- a/docs/guides/grpo.md
+++ b/docs/guides/grpo.md
@@ -1,3 +1,93 @@
-# GRPO
+# An in-depth walkthrough of GRPO in Reinforcer
 
-placeholder TBD
+## Quickstart: Launch a GRPO Run
+
+If you want to get running quickly, the script [examples/run_grpo_math.py](../../examples/run_grpo_math.py) has an example implementation of using GRPO to train a model on math problems. This script can be launched either locally or via Slurm. For details on how to set up Ray and launch a job using Slurm, refer to the [cluster documentation](../cluster.md).
+
+We recommend launching the job using `uv`:
+```bash
+uv run examples/run_grpo_math.py --config {overrides}
+```
+If not specified, `config` will default to [examples/configs/grpo_math_1B.yaml](../../examples/configs/grpo_math_1B.yaml)
+
+## Now, for the details:
+
+In this guide, we'll walk through how we handle:
+* Data
+* Model training
+* Fast generation
+* Overall Resource Flow
+
+### Data
+We support training with multiple RL "Environments" at the same time.
+
+An [Environment](../../nemo_reinforcer/environments/interfaces.py) is an object that accepts a state/action history and returns an updated state and rewards for the step. Environments run as Ray remote actors; see [MathEnvironment](../../nemo_reinforcer/environments/math_environment.py) for an example.
+
+To support this, we need to know:
+* What environments you have
+* Which data should go to which environments
+* How to prepare the data from your dataset into a form we can use
+
+#### Common Data Format
+We define a [DatumSpec](../../nemo_reinforcer/data/interfaces.py) that holds all relevant information for each training example:
+```python
+class DatumSpec(TypedDict):
+    message_log: LLMMessageLogType
+    length: int  # total (concatenated) length of the message tensors
+    extra_env_info: Dict[str, Any]  # anything your environment requires goes here, for example the 'answer' of a math problem
+    loss_multiplier: float  # multiplier for the loss for this datum. 0 to mask out (say the sample is invalid)
+    idx: int
+    task_name: Optional[str] = "default"
+    __extra__: Any  # This allows additional fields of any type
+```
+
+#### Data Processors
+We call each distinct "environment your model wants to optimize against" a "task". So you might define a "math" task or a "code" task.
+For each task, you should provide a data processor that reads from your dataset and returns a [DatumSpec](../../nemo_reinforcer/data/interfaces.py):
+
+```python
+def my_data_processor(
+    datum_dict: Dict[str, Any],  # loaded directly from your dataset (i.e. a single line of jsonl data)
+    task_data_spec: TaskDataSpec,
+    tokenizer,
+    max_seq_length: int,
+    idx: int,
+) -> DatumSpec:
+```
+We have an example of this as `math_data_processor` in [run_grpo_math.py](../../examples/run_grpo_math.py)
+
+#### Putting it all together:
+GRPO expects datasets to have the following form:
+```json
+{"task_name": "math", }
+```
+Then, you can set up the data as follows:
+```python
+base_dataset = load_dataset("json", data_files=data_config["dataset_name"])["train"]
+tokenizer = AutoTokenizer.from_pretrained(policy_config["model_name"])
+
+task_data_processors = defaultdict(lambda: (math_task_spec, math_data_processor))
+task_data_processors["math"] = (math_task_spec, math_data_processor)
+
+math_env = MathEnvironment.remote(env_configs["math"])  # ray remote actor
+
+dataset = AllTaskProcessedDataset(
+    base_dataset,
+    tokenizer,
+    math_task_spec,
+    task_data_processors,
+    max_seq_length=data_config["max_input_seq_length"],
+)
+```
+Notice that you provide a mapping of tasks to their processors so the dataset knows what to use when processing samples.
+
+
+### Policy Model
+We define a [PolicyInterface]() that contains everything you need to train a Policy model.
+
+This Policy object holds a [RayWorkerGroup](../../nemo_reinforcer/distributed/worker_groups.py) of SPMD (1 proc/gpu) processes that run HF/MCore, all coordinated by this object so it appears to you like 1 GPU!
+
+### Fast Generation
+We support vLLM through the [VllmGeneration](../../nemo_reinforcer/models/generation/vllm.py) class right now.
+
+The function [grpo_train](../../nemo_reinforcer/algorithms/grpo.py) contains the core GRPO training loop.
\ No newline at end of file
diff --git a/docs/guides/sft.md b/docs/guides/sft.md
index 1ae95dad1f..d994967bd2 100644
--- a/docs/guides/sft.md
+++ b/docs/guides/sft.md
@@ -6,7 +6,7 @@ The script [examples/run_sft.py](../../examples/run_sft.py) can be used to launc
 
 Be sure to launch the job using `uv`. The command to launch an SFT job is as follows:
 ```bash
-uv run examples/run_sft.py --config
+uv run examples/run_sft.py --config
 ```
 
 If not specified, `config` will default to [examples/configs/sft.yaml](../../examples/configs/sft.yaml).
diff --git a/examples/configs/grpo.yaml b/examples/configs/grpo_math_1B.yaml
similarity index 60%
rename from examples/configs/grpo.yaml
rename to examples/configs/grpo_math_1B.yaml
index 4436421009..684e24e7f5 100644
--- a/examples/configs/grpo.yaml
+++ b/examples/configs/grpo_math_1B.yaml
@@ -1,10 +1,8 @@
 # GRPO Algorithm Configuration
-defaults: "base.yaml"
-
 grpo:
   num_prompts_per_step: 8
   num_generations_per_prompt: 8
-  num_steps: 100
+  max_num_steps: 100
   normalize_rewards: true
   use_leave_one_out_baseline: true
   val_period: 10
@@ -29,16 +27,28 @@ policy:
   train_global_batch_size: 32
   train_micro_batch_size: 4
   generation_batch_size: 32
+  learning_rate: 5.0e-6
   logprob_batch_size: 4
-  max_total_sequence_length: 1024
+  max_total_sequence_length: 512
+
+  scheduler:
+    - name: "torch.optim.lr_scheduler.LinearLR"
+      kwargs:
+        start_factor: 0.1
+        end_factor: 1.0
+        total_iters: 50
+    - name: "torch.optim.lr_scheduler.ConstantLR"
+      kwargs:
+        factor: 1.0
+        total_iters: 10000000000
+    - milestones: [50]
 
   generation:
-    backend: "vllm" # "vllm" or "hf"(to use the hf training framework's generation)
-    max_new_tokens: ${policy.max_total_sequence_length} # upper bound, real truncation occurs at vllm.max_model_len below
+    backend: "vllm"
+    max_new_tokens: ${policy.max_total_sequence_length}
     temperature: 1.0
-    # Don't change since vllm logprobs in V0 runtime are after sampling and in V1 runtime are before sampling.
     top_p: 1.0
-    top_k: null # disable
+    top_k: null
     vllm_cfg:
       tensor_parallel_size: 1
       gpu_memory_utilization: 0.7
@@ -54,3 +64,17 @@ data:
 env:
   math:
     num_workers: 8
+
+logger:
+  log_dir: "logs"  # Base directory for all logs
+  num_val_samples_to_print: 0  # Number of validation samples to pretty print on terminal
+  wandb_enabled: false
+  tensorboard_enabled: false
+  wandb:
+    project: "grpo-dev"
+    name: "grpo-dev-logger"
+  tensorboard: {}
+
+cluster:
+  gpus_per_node: 1
+  num_nodes: 1
\ No newline at end of file
diff --git a/examples/configs/base.yaml b/examples/configs/grpo_math_8B.yaml
similarity index 50%
rename from examples/configs/base.yaml
rename to examples/configs/grpo_math_8B.yaml
index ded5e8255c..2415ee8cd8 100644
--- a/examples/configs/base.yaml
+++ b/examples/configs/grpo_math_8B.yaml
@@ -1,12 +1,14 @@
-# Base configuration with common settings
+# GRPO Algorithm Configuration
+defaults: "grpo_math_1B.yaml"
+
 policy:
-  model_name: "meta-llama/Llama-3.2-1B-Instruct"
+  model_name: "meta-llama/Llama-3.1-8B-Instruct"
   train_global_batch_size: 32
-  train_micro_batch_size: 4
+  train_micro_batch_size: 1
   generation_batch_size: 32
   learning_rate: 5.0e-6
-  logprob_batch_size: 4
-  max_total_sequence_length: 8192
+  logprob_batch_size: 2
+  max_total_sequence_length: 4096
 
   scheduler:
     - name: "torch.optim.lr_scheduler.LinearLR"
@@ -28,24 +30,9 @@ policy:
     top_k: null
   vllm_cfg:
     tensor_parallel_size: 1
-    gpu_memory_utilization: 0.7
+    gpu_memory_utilization: 0.6
     max_model_len: ${policy.max_total_sequence_length}
-
-data:
-  max_input_seq_length: ${policy.max_total_sequence_length}
-  prompt_file: "examples/prompts/cot.txt"
-  system_prompt_file: null
-
-logger:
-  log_dir: "logs"  # Base directory for all logs
-  num_val_samples_to_print: 0  # Number of validation samples to pretty print on terminal
-  wandb_enabled: false
-  tensorboard_enabled: false
-  wandb:
-    project: "grpo-dev"
-    name: "grpo-dev-logger"
-  tensorboard: {}
-
+
 cluster:
-  gpus_per_node: 2
+  gpus_per_node: 8
   num_nodes: 1
\ No newline at end of file
diff --git a/examples/configs/sft.yaml b/examples/configs/sft.yaml
index b938a9d321..38db492017 100644
--- a/examples/configs/sft.yaml
+++ b/examples/configs/sft.yaml
@@ -1,6 +1,6 @@
 # SFT Algorithm Configuration
 sft:
-  num_steps: 20
+  max_num_steps: 20
   val_period: 10
   val_batches: 8
   val_global_batch_size: 32
diff --git a/examples/run_grpo_math.py b/examples/run_grpo_math.py
index b3fadbb563..4e7732def1 100644
--- a/examples/run_grpo_math.py
+++ b/examples/run_grpo_math.py
@@ -170,7 +170,7 @@ def main():
     args, overrides = parse_args()
 
     if not args.config:
-        args.config = os.path.join(os.path.dirname(__file__), "configs", "grpo.yaml")
+        args.config = os.path.join(os.path.dirname(__file__), "configs", "grpo_math_1B.yaml")
 
     config = load_config(args.config)
     print(f"Loaded configuration from: {args.config}")
diff --git a/nemo_reinforcer/algorithms/grpo.py b/nemo_reinforcer/algorithms/grpo.py
index 0f120330fa..94b9f29a2b 100644
--- a/nemo_reinforcer/algorithms/grpo.py
+++ b/nemo_reinforcer/algorithms/grpo.py
@@ -69,7 +69,7 @@ class GRPOConfig(TypedDict):
     num_prompts_per_step: int
     num_generations_per_prompt: int
-    num_steps: int
+    max_num_steps: int
     normalize_rewards: bool
    use_leave_one_out_baseline: bool
     val_period: int
@@ -445,7 +445,7 @@ def grpo_train(
     # Run grpo training (single-turn)
     for batch in dataloader:
-        print(f"\n{'=' * 25} Step {step + 1}/{len(dataloader)} {'=' * 25}")
+        print(f"\n{'=' * 25} Step {step + 1}/{min(len(dataloader), master_config['grpo']['max_num_steps'])} {'=' * 25}")
 
         with timer.time("total_step_time"):
             # Prepare batch
@@ -654,7 +654,7 @@ def grpo_train(
             timer.reset()
             step += 1
-            if step >= master_config["grpo"]["num_steps"]:
+            if step >= master_config["grpo"]["max_num_steps"]:
                 break
diff --git a/nemo_reinforcer/algorithms/sft.py b/nemo_reinforcer/algorithms/sft.py
index 402b5bad92..a8d065d37d 100644
--- a/nemo_reinforcer/algorithms/sft.py
+++ b/nemo_reinforcer/algorithms/sft.py
@@ -51,7 +51,7 @@ def _default_sft_save_state() -> SFTSaveState:
 
 class SFTConfig(TypedDict):
-    num_steps: int
+    max_num_steps: int
     val_period: int
     val_batches: int
     val_global_batch_size: int
@@ -437,5 +437,5 @@ def sft_train(
         timer.reset()
         step += 1
-        if step >= master_config["sft"]["num_steps"]:
+        if step >= master_config["sft"]["max_num_steps"]:
             break
diff --git a/tests/functional/grpo.sh b/tests/functional/grpo.sh
index 56ea95026d..cfbc1bc712 100755
--- a/tests/functional/grpo.sh
+++ b/tests/functional/grpo.sh
@@ -16,7 +16,7 @@ mkdir -p $LOG_DIR
 cd $PROJECT_ROOT
 uv run $PROJECT_ROOT/examples/run_grpo_math.py \
     cluster.gpus_per_node=2 \
-    grpo.num_steps=10 \
+    grpo.max_num_steps=10 \
     logger.tensorboard_enabled=true \
     logger.log_dir=$LOG_DIR \
     logger.wandb_enabled=false \
diff --git a/tests/functional/sft.sh b/tests/functional/sft.sh
index 94d565d700..0e2298c983 100755
--- a/tests/functional/sft.sh
+++ b/tests/functional/sft.sh
@@ -16,7 +16,7 @@ mkdir -p $LOG_DIR
 cd $PROJECT_ROOT
 uv run $PROJECT_ROOT/examples/run_sft.py \
     cluster.gpus_per_node=2 \
-    sft.num_steps=10 \
+    sft.max_num_steps=10 \
     logger.tensorboard_enabled=true \
     logger.log_dir=$LOG_DIR \
     logger.wandb_enabled=false \
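
---

Editor's note: the single-controller GRPO loop added to `design_and_philosophy.md` above is pseudocode. Below is a minimal runnable sketch of the same control flow with toy stand-ins (every class and function name here is hypothetical; none of this is the real Reinforcer API). It shows the two GRPO-specific steps the doc mentions: repeating each prompt `num_generations_per_prompt` times, and the leave-one-out baseline enabled by `use_leave_one_out_baseline`.

```python
from typing import List


class ToyGeneration:
    """Stand-in for a generation backend (e.g. the vLLM worker group)."""

    def generate(self, batch: List[str]) -> List[str]:
        return [prompt + " -> answer" for prompt in batch]


class ToyMathEnvironment:
    """Stand-in for an Environment: scores each generation with a reward."""

    def step(self, generations: List[str]) -> List[float]:
        return [1.0 if "answer" in g else 0.0 for g in generations]


def leave_one_out_baseline(rewards: List[float], group: int) -> List[float]:
    """GRPO-style advantage: each sample's reward minus the mean reward of
    the *other* generations for the same prompt (leave-one-out baseline)."""
    advantages = []
    for start in range(0, len(rewards), group):
        chunk = rewards[start:start + group]
        for i, r in enumerate(chunk):
            others = chunk[:i] + chunk[i + 1:]
            baseline = sum(others) / len(others) if others else 0.0
            advantages.append(r - baseline)
    return advantages


gen, env = ToyGeneration(), ToyMathEnvironment()
num_generations_per_prompt = 2
prompts = ["1+1", "2+2"]
# repeat_interleave: every prompt appears num_generations_per_prompt times
batch = [p for p in prompts for _ in range(num_generations_per_prompt)]
rewards = env.step(gen.generate(batch))
print(leave_one_out_baseline(rewards, num_generations_per_prompt))  # [0.0, 0.0, 0.0, 0.0]
```

When every generation in a group earns the same reward, the leave-one-out advantages are all zero, which is why reward diversity within a group matters for GRPO.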
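The `DatumSpec` / data-processor contract described in the new `docs/guides/grpo.md` can be illustrated with a toy, tokenizer-free processor. This is a sketch only: `toy_math_data_processor` is a hypothetical name (not the real `math_data_processor`), and whitespace splitting stands in for a real tokenizer. It shows the two behaviors the doc calls out: stashing environment-specific data (the `answer`) in `extra_env_info`, and zeroing `loss_multiplier` to mask an invalid (here, over-long) sample.

```python
from typing import Any, Dict


def toy_math_data_processor(
    datum_dict: Dict[str, Any],  # e.g. one parsed line of jsonl
    max_seq_length: int,
    idx: int,
) -> Dict[str, Any]:
    """Build a DatumSpec-shaped dict from one raw dataset row."""
    tokens = datum_dict["problem"].split()  # toy tokenizer
    datum = {
        "message_log": [{"role": "user", "content": tokens}],
        "length": len(tokens),
        # anything the environment needs to score the rollout goes here
        "extra_env_info": {"answer": datum_dict["answer"]},
        "loss_multiplier": 1.0,
        "idx": idx,
        "task_name": datum_dict.get("task_name", "default"),
    }
    if datum["length"] > max_seq_length:
        datum["loss_multiplier"] = 0.0  # mask out over-long samples
    return datum


row = {"task_name": "math", "problem": "What is 2 + 2 ?", "answer": "4"}
spec = toy_math_data_processor(row, max_seq_length=32, idx=0)
print(spec["length"], spec["task_name"], spec["loss_multiplier"])  # 6 math 1.0
```

The `task_name` field is what routes each processed datum to the matching entry in the `task_data_processors` mapping and, later, to the matching environment.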
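The `num_steps` → `max_num_steps` rename in this diff changes the key's meaning from an exact step count to an upper bound: training ends at whichever comes first, dataloader exhaustion or the cap, and the progress line prints `min(len(dataloader), max_num_steps)` as the total. A toy sketch of that loop logic (not the real `grpo_train`/`sft_train`):

```python
from typing import Sequence


def run_steps(dataloader: Sequence, max_num_steps: int) -> int:
    """Iterate the dataloader, stopping early once max_num_steps is reached."""
    # total shown in the progress line: the run can never be longer than
    # either the dataloader or the step cap
    total = min(len(dataloader), max_num_steps)
    step = 0
    for _batch in dataloader:
        print(f"Step {step + 1}/{total}")
        step += 1
        if step >= max_num_steps:
            break
    return step


print(run_steps(list(range(100)), max_num_steps=3))  # cap wins: 3
print(run_steps(list(range(2)), max_num_steps=3))    # short dataloader wins: 2
```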