Skip to content
Merged
Show file tree
Hide file tree
Changes from 24 commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
7f0a931
WIP init commit
parthchadha Jul 23, 2025
01e8887
Add checks to prevent cuda oom
parthchadha Aug 2, 2025
15364fc
Merge remote-tracking branch 'origin/main' into pchadha/async-basic
parthchadha Aug 4, 2025
c1160d5
Merge remote-tracking branch 'origin/main' into pchadha/async-basic
parthchadha Aug 5, 2025
f66f4c0
Default max_trajectory_age_steps=1 by default
parthchadha Aug 5, 2025
509312a
debug implementation
parthchadha Aug 13, 2025
621a1fa
fix lock for writing into dict
parthchadha Aug 13, 2025
c196790
Fix stalls
parthchadha Aug 13, 2025
653e4f5
More fixes
parthchadha Aug 13, 2025
718c332
Add stronger check for stall
parthchadha Aug 15, 2025
722fae8
Fix more stalling issues and wrong batch data use
parthchadha Aug 15, 2025
160a350
Fix incorrect clearning of inflight generation targets
parthchadha Aug 19, 2025
01f1cd6
Add stall on refit and log avg age of samples
parthchadha Aug 20, 2025
0b7c8f9
Save the state of dataloader from the collector
parthchadha Aug 22, 2025
987bdfe
Merge remote-tracking branch 'origin/main' into faster-strictfifo
parthchadha Aug 26, 2025
ff52010
Fix incorrect passing of gbs args which reduced async experiments to …
parthchadha Aug 29, 2025
81fc3a5
Add assertion when async grpo is used with sync vllm engine
parthchadha Aug 29, 2025
7ca538d
fix: Decouple exposed_generation time from weight_sync time (#1052)
youngeunkwon0405 Sep 3, 2025
a1f35cb
fix: issue where generation_weight_version isn't correct after refit …
RahulSChand Sep 3, 2025
ca84567
Merge remote-tracking branch 'origin/main' into faster-strictfifo
parthchadha Sep 3, 2025
07bcd9a
Merge remote-tracking branch 'origin/faster-strictfifo' into faster-s…
parthchadha Sep 3, 2025
b321181
Merge remote-tracking branch 'origin/main' into faster-strictfifo
parthchadha Sep 8, 2025
576634c
Add more detailed comments about async
parthchadha Sep 8, 2025
6868407
Add more comments; resolve review feedback
parthchadha Sep 9, 2025
c8ee01f
Merge remote-tracking branch 'origin/main' into faster-strictfifo
parthchadha Sep 15, 2025
62eec29
Move async config to grpo/, remove async config examples, clean up code
parthchadha Sep 15, 2025
d7e1e29
Add async grpo docs
parthchadha Sep 17, 2025
4ace692
Add functional L1 test
parthchadha Sep 18, 2025
215e7a9
Merge remote-tracking branch 'origin/main' into faster-strictfifo
parthchadha Sep 18, 2025
7f90b60
Update configs with async_grpo flag
parthchadha Sep 18, 2025
79944bf
Add diagram in docs
parthchadha Sep 18, 2025
c1bbe8b
Apply suggestions from code review
parthchadha Sep 18, 2025
c0a144a
Merge remote-tracking branch 'origin/main' into faster-strictfifo
parthchadha Sep 18, 2025
ad42ecf
Merge remote-tracking branch 'origin/main' into faster-strictfifo
parthchadha Sep 18, 2025
3a7ca57
fix doc failure
parthchadha Sep 18, 2025
bf5a6d7
feat: async RL
terrykong Sep 19, 2025
b534850
Add missing async unit tests
parthchadha Sep 19, 2025
a7f12d5
Merge remote-tracking branch 'origin/faster-strictfifo' into faster-s…
parthchadha Sep 19, 2025
8eb7c93
Add missing grpo config in vlm yaml
parthchadha Sep 22, 2025
2c7a922
Merge remote-tracking branch 'origin/main' into faster-strictfifo
parthchadha Sep 22, 2025
7ed8a90
Raise error if ReplayBuffer created with <= 0 size
parthchadha Sep 22, 2025
c10f034
Add missing pragma no cover to ray remote async class
parthchadha Sep 22, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 109 additions & 0 deletions examples/configs/async_grpo_math_1B.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
# Async GRPO Algorithm Configuration
defaults: "grpo_math_1B.yaml"

# Async-specific settings
async_grpo:
enabled: true # Enable async training
max_trajectory_age_steps: 1 # Allow trajectories from the last 4 training steps

grpo:
num_prompts_per_step: 32
num_generations_per_prompt: 16
max_rollout_turns: 1
max_num_steps: 1000
normalize_rewards: true
use_leave_one_out_baseline: true
val_period: 10
val_at_start: false
max_val_samples: 256
val_batch_size: 256

policy:
model_name: "Qwen/Qwen2.5-1.5B"
tokenizer:
name: ${policy.model_name}
train_global_batch_size: 512
train_micro_batch_size: 4
generation_batch_size: 32
logprob_batch_size: 4
max_total_sequence_length: 512
precision: "bfloat16"
fsdp_offload_enabled: false
activation_checkpointing_enabled: false

dtensor_cfg:
enabled: true
cpu_offload: False
sequence_parallel: false
activation_checkpointing: false
tensor_parallel_size: 1
context_parallel_size: 1
custom_parallel_plan: null

dynamic_batching:
enabled: True

sequence_packing:
enabled: False

generation:
backend: "vllm"
max_new_tokens: ${policy.max_total_sequence_length}
temperature: 1.0
top_p: 1.0
top_k: null
stop_token_ids: null
stop_strings: null
vllm_cfg:
# Enable async engine for better concurrency
async_engine: true
tensor_parallel_size: 1
gpu_memory_utilization: 0.8
max_model_len: ${policy.max_total_sequence_length}
enforce_eager: False
colocated:
enabled: false
resources:
gpus_per_node: 1
num_nodes: 1

# Data configuration
data:
dataset_name: "OpenMathInstruct-2"
subset: "train"
limit: 1000 # For faster testing
val_limit: 100
val_subset: "test"
max_input_seq_length: ${policy.max_total_sequence_length}
prompt_file: "examples/prompts/cot.txt"
system_prompt_file: null
add_system_prompt: false

# Environment configuration
env:
math:
cfg:
num_workers: 4

checkpointing:
enabled: true
checkpoint_dir: "results/async_grpo_importance_sampling"
metric_name: "val_reward"
higher_is_better: true
keep_top_k: 3
save_period: 10

logger:
log_dir: "logs"
num_val_samples_to_print: 3
wandb_enabled: true
tensorboard_enabled: false
monitor_gpus: true
wandb:
project: "async-grpo-dev"
name: "async-grpo-math"
tensorboard: {}
gpu_monitoring:
collection_interval: 10
flush_interval: 10

98 changes: 98 additions & 0 deletions examples/configs/async_grpo_math_8B.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
# Async GRPO Algorithm Configuration
defaults: "grpo_math_1B.yaml"

# Async-specific settings
async_grpo:
enabled: true # Enable async training
max_trajectory_age_steps: 1 # Allow trajectories from the last 1 training steps

grpo:
num_prompts_per_step: 64
num_generations_per_prompt: 32
Comment thread
terrykong marked this conversation as resolved.
Outdated

# Use async-friendly backend
policy:
model_name: "meta-llama/Llama-3.1-8B-Instruct"
tokenizer:
name: ${policy.model_name} ## specify if you'd like to use a tokenizer different from the model's default
train_global_batch_size: 512
train_micro_batch_size: 1
generation_batch_size: 32 # Only used when generating using HF backend
logprob_batch_size: 2
max_total_sequence_length: 4096
precision: "bfloat16"
fsdp_offload_enabled: false
activation_checkpointing_enabled: false

dtensor_cfg:
enabled: true
cpu_offload: False
sequence_parallel: false
activation_checkpointing: false
tensor_parallel_size: 1
context_parallel_size: 1
custom_parallel_plan: null

dynamic_batching:
enabled: False

optimizer:
name: "torch.optim.AdamW"
kwargs:
lr: 3.0e-7
weight_decay: 0.01
betas: [0.9, 0.999]
eps: 1e-8

scheduler:
- name: "torch.optim.lr_scheduler.LinearLR"
kwargs:
start_factor: 0.1
end_factor: 1.0
# The scheduler iteration is per GPRO step and is decoupled with the optimizer step (may be >=1 per GPRO step)
total_iters: 13
- name: "torch.optim.lr_scheduler.ConstantLR"
kwargs:
factor: 1.0
total_iters: 10000000000
- milestones: [13]

Comment thread
coderabbitai[bot] marked this conversation as resolved.
Outdated
generation:
backend: "vllm"
max_new_tokens: ${policy.max_total_sequence_length}
temperature: 1.0
top_p: 1.0
top_k: null
stop_token_ids: null
stop_strings: null
vllm_cfg:
# Enable async engine for better concurrency
async_engine: true
tensor_parallel_size: 1
gpu_memory_utilization: 0.8
max_model_len: ${policy.max_total_sequence_length}
enforce_eager: False
colocated:
enabled: false
resources:
gpus_per_node: 1
num_nodes: 1

cluster:
gpus_per_node: 8
num_nodes: 1

logger:
log_dir: "logs"
num_val_samples_to_print: 3
wandb_enabled: true
tensorboard_enabled: false
monitor_gpus: true
wandb:
project: "async-grpo-dev"
name: "async-grpo-math"
tensorboard: {}
gpu_monitoring:
collection_interval: 10
flush_interval: 10

55 changes: 41 additions & 14 deletions examples/run_grpo_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,20 +252,47 @@ def main() -> None:
master_config,
) = setup(config, tokenizer, dataset, val_dataset)

grpo_train(
policy,
policy_generation,
dataloader,
val_dataloader,
tokenizer,
loss_fn,
task_to_env,
val_task_to_env,
logger,
checkpointer,
grpo_state,
master_config,
)
# Check if async mode is enabled
async_config = config.get("async_grpo", {})
if async_config and async_config.get("enabled", False):
from nemo_rl.algorithms.grpo import async_grpo_train

print("🚀 Running async GRPO training")

# Run async GRPO training
async_grpo_train(
policy=policy,
policy_generation=policy_generation,
dataloader=dataloader,
val_dataloader=val_dataloader,
tokenizer=tokenizer,
loss_fn=loss_fn,
task_to_env=task_to_env,
val_task_to_env=val_task_to_env,
logger=logger,
checkpointer=checkpointer,
grpo_save_state=grpo_state,
master_config=master_config,
max_trajectory_age_steps=async_config.get("max_trajectory_age_steps", 1),
)
else:
print("🚀 Running synchronous GRPO training")

# Run standard GRPO training
grpo_train(
policy,
policy_generation,
dataloader,
val_dataloader,
tokenizer,
loss_fn,
task_to_env,
val_task_to_env,
logger,
checkpointer,
grpo_state,
master_config,
)


if __name__ == "__main__":
Expand Down
Loading