Merged

42 commits
7f0a931
WIP init commit
parthchadha Jul 23, 2025
01e8887
Add checks to prevent cuda oom
parthchadha Aug 2, 2025
15364fc
Merge remote-tracking branch 'origin/main' into pchadha/async-basic
parthchadha Aug 4, 2025
c1160d5
Merge remote-tracking branch 'origin/main' into pchadha/async-basic
parthchadha Aug 5, 2025
f66f4c0
Default max_trajectory_age_steps=1 by default
parthchadha Aug 5, 2025
509312a
debug implementation
parthchadha Aug 13, 2025
621a1fa
fix lock for writing into dict
parthchadha Aug 13, 2025
c196790
Fix stalls
parthchadha Aug 13, 2025
653e4f5
More fixes
parthchadha Aug 13, 2025
718c332
Add stronger check for stall
parthchadha Aug 15, 2025
722fae8
Fix more stalling issues and wrong batch data use
parthchadha Aug 15, 2025
160a350
Fix incorrect clearing of inflight generation targets
parthchadha Aug 19, 2025
01f1cd6
Add stall on refit and log avg age of samples
parthchadha Aug 20, 2025
0b7c8f9
Save the state of dataloader from the collector
parthchadha Aug 22, 2025
987bdfe
Merge remote-tracking branch 'origin/main' into faster-strictfifo
parthchadha Aug 26, 2025
ff52010
Fix incorrect passing of gbs args which reduced async experiments to …
parthchadha Aug 29, 2025
81fc3a5
Add assertion when async grpo is used with sync vllm engine
parthchadha Aug 29, 2025
7ca538d
fix: Decouple exposed_generation time from weight_sync time (#1052)
youngeunkwon0405 Sep 3, 2025
a1f35cb
fix: issue where generation_weight_version isn't correct after refit …
RahulSChand Sep 3, 2025
ca84567
Merge remote-tracking branch 'origin/main' into faster-strictfifo
parthchadha Sep 3, 2025
07bcd9a
Merge remote-tracking branch 'origin/faster-strictfifo' into faster-s…
parthchadha Sep 3, 2025
b321181
Merge remote-tracking branch 'origin/main' into faster-strictfifo
parthchadha Sep 8, 2025
576634c
Add more detailed comments about async
parthchadha Sep 8, 2025
6868407
Add more comments; resolve review feedback
parthchadha Sep 9, 2025
c8ee01f
Merge remote-tracking branch 'origin/main' into faster-strictfifo
parthchadha Sep 15, 2025
62eec29
Move async config to grpo/, remove async config examples, clean up code
parthchadha Sep 15, 2025
d7e1e29
Add async grpo docs
parthchadha Sep 17, 2025
4ace692
Add functional L1 test
parthchadha Sep 18, 2025
215e7a9
Merge remote-tracking branch 'origin/main' into faster-strictfifo
parthchadha Sep 18, 2025
7f90b60
Update configs with async_grpo flag
parthchadha Sep 18, 2025
79944bf
Add diagram in docs
parthchadha Sep 18, 2025
c1bbe8b
Apply suggestions from code review
parthchadha Sep 18, 2025
c0a144a
Merge remote-tracking branch 'origin/main' into faster-strictfifo
parthchadha Sep 18, 2025
ad42ecf
Merge remote-tracking branch 'origin/main' into faster-strictfifo
parthchadha Sep 18, 2025
3a7ca57
fix doc failure
parthchadha Sep 18, 2025
bf5a6d7
feat: async RL
terrykong Sep 19, 2025
b534850
Add missing async unit tests
parthchadha Sep 19, 2025
a7f12d5
Merge remote-tracking branch 'origin/faster-strictfifo' into faster-s…
parthchadha Sep 19, 2025
8eb7c93
Add missing grpo config in vlm yaml
parthchadha Sep 22, 2025
2c7a922
Merge remote-tracking branch 'origin/main' into faster-strictfifo
parthchadha Sep 22, 2025
7ed8a90
Raise error if ReplayBuffer created with <= 0 size
parthchadha Sep 22, 2025
c10f034
Add missing pragma no cover to ray remote async class
parthchadha Sep 22, 2025
157 changes: 157 additions & 0 deletions docs/guides/async-grpo.md
@@ -0,0 +1,157 @@
# Train with Async GRPO

Async GRPO is an asynchronous training mode that allows trajectory generation and policy training to run concurrently, improving GPU utilization and throughput compared to synchronous GRPO.

## Configure Async GRPO

This section describes the configuration changes required to enable async GRPO and provides a complete example configuration.

### Enable Async GRPO

To use async GRPO, make these configuration changes:

1. **Enable vLLM async engine**:
   ```yaml
   policy:
     generation:
       backend: "vllm"
       vllm_cfg:
         async_engine: true
   ```

2. **Enable importance sampling correction** (required for convergence):
   ```yaml
   loss_fn:
     use_importance_sampling_correction: true
   ```

3. **Disable colocated inference** (required for async mode):
   ```yaml
   policy:
     generation:
       colocated:
         enabled: false
         resources:
           num_nodes: 1 # or more
           gpus_per_node: 2 # adjust based on your setup
   ```

4. **Add async GRPO configuration**:
   ```yaml
   grpo:
     async_grpo:
       enabled: true
       max_trajectory_age_steps: 1 # Maximum age, in training steps, for trajectories
   ```

### Complete Example Config

```yaml
policy:
  generation:
    backend: "vllm"
    colocated:
      enabled: false
      resources:
        num_nodes: 1
        gpus_per_node: 2
    vllm_cfg:
      async_engine: true

loss_fn:
  use_importance_sampling_correction: true

grpo:
  num_prompts_per_step: 32
  num_generations_per_prompt: 4
  async_grpo:
    enabled: true
    max_trajectory_age_steps: 1

cluster:
  num_nodes: 2
  gpus_per_node: 4
```
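Taken together, the requirements above can be expressed as a small pre-flight check. The function below is an illustrative sketch, not part of NeMo RL; the nested dict mirrors the YAML layout shown above:

```python
# Hypothetical pre-flight check mirroring the async GRPO requirements above.
def validate_async_grpo(cfg: dict) -> None:
    gen = cfg["policy"]["generation"]
    if not gen["vllm_cfg"]["async_engine"]:
        raise ValueError("async GRPO requires vllm_cfg.async_engine: true")
    if gen["colocated"]["enabled"]:
        raise ValueError("async GRPO requires colocated.enabled: false")
    if not cfg["loss_fn"]["use_importance_sampling_correction"]:
        raise ValueError("async GRPO requires use_importance_sampling_correction: true")

cfg = {
    "policy": {"generation": {
        "vllm_cfg": {"async_engine": True},
        "colocated": {"enabled": False},
    }},
    "loss_fn": {"use_importance_sampling_correction": True},
}
validate_async_grpo(cfg)  # a valid config passes silently
```

Running such a check at startup fails fast instead of surfacing a convergence or scheduling problem mid-training.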

## Implementation Structure

This section describes the internal architecture of async GRPO and how its core components interact.

### Core Components

The async GRPO implementation consists of three main components:

#### 1. Main Training Loop (`async_grpo_train` in `grpo.py`)
- Coordinates overall training process
- Samples trajectories from replay buffer
- Runs policy training steps
- Handles validation and checkpointing
- Manages weight synchronization between training and generation

#### 2. Async Trajectory Collector (`AsyncTrajectoryCollector` in `async_utils.py`)
- Runs in background Ray actor
- Continuously generates trajectories using current policy weights
- Manages generation scheduling and weight version tracking
- Handles pause/resume for weight updates and validation
- Coordinates with replay buffer for trajectory storage

#### 3. Replay Buffer (`ReplayBuffer` in `async_utils.py`)
- Stores generated trajectories with metadata
- Tracks weight versions for both generation and intended training use
- Implements age-based filtering to prevent stale trajectories
- Provides sampling interface for training steps
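As a rough sketch, the buffer's core behavior (bounded storage plus age-based filtering) looks like the following. The class and method names are illustrative stand-ins, not the actual `ReplayBuffer` API in `async_utils.py`:

```python
from collections import deque

class SketchReplayBuffer:
    """Illustrative stand-in for the real ReplayBuffer."""

    def __init__(self, max_size: int, max_trajectory_age_steps: int):
        if max_size <= 0:
            # The real implementation also raises on non-positive sizes
            raise ValueError("ReplayBuffer size must be > 0")
        self._items = deque(maxlen=max_size)
        self._max_age = max_trajectory_age_steps

    def push(self, trajectory, generation_weight_version: int) -> None:
        # Each trajectory is stored with the weight version that generated it
        self._items.append((generation_weight_version, trajectory))

    def sample(self, current_step: int, n: int) -> list:
        # Age-based filtering: drop anything older than max_trajectory_age_steps
        fresh = [t for v, t in self._items
                 if current_step - v <= self._max_age]
        return fresh[:n]
```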

### Weight Version Tracking

Async GRPO uses a weight versioning system:
- **Generation Weight Version**: The policy weights used to generate a trajectory
- **Target Weight Version**: The training step where the trajectory will be used
- **Max Trajectory Age**: How many steps old a trajectory can be before being discarded

Example with `max_trajectory_age_steps: 1`:
- Trajectory generated with weights v10 can be used for training steps v10 or v11
- At training step v12, trajectories from v10 are too old and discarded
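The staleness rule in this example reduces to a one-line predicate (the function name is illustrative):

```python
def trajectory_is_fresh(generation_version: int, current_step: int,
                        max_trajectory_age_steps: int = 1) -> bool:
    # A trajectory stays usable while its age, in training steps,
    # is at most max_trajectory_age_steps.
    return current_step - generation_version <= max_trajectory_age_steps

print(trajectory_is_fresh(10, 10))  # True  (age 0)
print(trajectory_is_fresh(10, 11))  # True  (age 1)
print(trajectory_is_fresh(10, 12))  # False (age 2: discarded)
```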

### Coordination Flow

1. **Startup**: Trajectory collector starts generating trajectories in background
2. **Buffer Fill**: Training waits until buffer has sufficient trajectories
3. **Training Step**:
- Sample trajectories from buffer
- Run policy training
- Update weights and notify collector
4. **Weight Sync**: Collector pauses, waits for weight refit, then resumes
5. **Repeat**: Process continues with updated weights


### Architecture Diagram

The following sequence diagram illustrates the interactions between the three main components:

```mermaid
sequenceDiagram
    participant Training as Training Loop
    participant Collector as Trajectory Collector
    participant Buffer as Replay Buffer

    Note over Training, Buffer: Startup
    Training->>Collector: Start generation
    Training->>Buffer: Initialize

    Note over Training, Buffer: Main Loop
    loop Async Training
        par Background Generation
            Collector->>Buffer: Store trajectories
        and Training Steps
            Training->>Buffer: Sample trajectories
            Buffer-->>Training: Return valid data
            Training->>Training: Update policy weights
            Training->>Collector: Sync new weights
        end
    end
```

## Usage Tips

1. **Buffer Sizing**: The replay buffer size is calculated automatically as:
   ```
   buffer_size = num_prompts_per_step × max_trajectory_age_steps × 2
   ```

2. **Age Limits**: Start with `max_trajectory_age_steps: 1`; increasing it can raise throughput at the cost of training on staler trajectories.

3. **Resource Allocation**: Ensure sufficient GPU memory for both the training and generation clusters
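For the example config above (`num_prompts_per_step: 32`, `max_trajectory_age_steps: 1`), the automatic buffer sizing works out to:

```python
num_prompts_per_step = 32          # from the example config
max_trajectory_age_steps = 1
buffer_size = num_prompts_per_step * max_trajectory_age_steps * 2
print(buffer_size)  # 64
```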
1 change: 1 addition & 0 deletions docs/index.md
@@ -33,6 +33,7 @@ guides/environments.md
guides/eval.md
guides/deepseek.md
model-quirks.md
guides/async-grpo.md
```

```{toctree}
6 changes: 6 additions & 0 deletions examples/configs/grpo_math_1B.yaml
@@ -13,6 +13,10 @@ grpo:
max_val_samples: 256
val_batch_size: 256
seed: 42
async_grpo:
enabled: false # Set to true to enable async training mode
# Max age (in training steps) for trajectories used in training
max_trajectory_age_steps: 1

loss_fn:
reference_policy_kl_penalty: 0.01
@@ -21,6 +25,8 @@ loss_fn:
ratio_clip_c: null
# (default off) loss formulation improvements (docs/guides/grpo.md#loss)
use_on_policy_kl_approximation: false
# Async GRPO requires importance sampling correction enabled
# Set to true when async_grpo.enabled is true
use_importance_sampling_correction: false
sequence_level_importance_ratios: false
token_level_loss: true
3 changes: 3 additions & 0 deletions examples/configs/grpo_math_1B_megatron.yaml
@@ -12,6 +12,9 @@ grpo:
val_at_start: false
max_val_samples: 256
val_batch_size: 256
async_grpo:
enabled: false
max_trajectory_age_steps: 1

loss_fn:
reference_policy_kl_penalty: 0.01
3 changes: 3 additions & 0 deletions examples/configs/grpo_math_8B.yaml
@@ -4,6 +4,9 @@ defaults: "grpo_math_1B.yaml"
grpo:
num_prompts_per_step: 64
num_generations_per_prompt: 32
async_grpo:
enabled: false
max_trajectory_age_steps: 1

policy:
model_name: "meta-llama/Llama-3.1-8B-Instruct"
3 changes: 3 additions & 0 deletions examples/configs/recipes/llm/grpo-deepscaler-1.5b-8K.yaml
@@ -13,6 +13,9 @@ grpo:
val_batch_size: 32
seed: 42
overlong_filtering: false
async_grpo:
enabled: false
max_trajectory_age_steps: 1

loss_fn:
reference_policy_kl_penalty: 0.0
@@ -12,6 +12,10 @@ grpo:
val_batch_size: 256
seed: 42
overlong_filtering: false
async_grpo:
enabled: false
max_trajectory_age_steps: 1

loss_fn:
reference_policy_kl_penalty: 0.01
ratio_clip_min: 0.2
@@ -12,6 +12,10 @@ grpo:
val_batch_size: 256
seed: 42
overlong_filtering: false
async_grpo:
enabled: false
max_trajectory_age_steps: 1

loss_fn:
reference_policy_kl_penalty: 0.01
ratio_clip_min: 0.2
@@ -13,6 +13,9 @@ grpo:
max_val_samples: 480
val_batch_size: 32
seed: 42
async_grpo:
enabled: false
max_trajectory_age_steps: 1

loss_fn:
reference_policy_kl_penalty: 0.0
@@ -12,6 +12,10 @@ grpo:
max_val_samples: 256
val_batch_size: 256
seed: 42
async_grpo:
enabled: false
max_trajectory_age_steps: 1

loss_fn:
reference_policy_kl_penalty: 0.01
ratio_clip_min: 0.2
@@ -12,6 +12,10 @@ grpo:
max_val_samples: 256
val_batch_size: 256
seed: 42
async_grpo:
enabled: false
max_trajectory_age_steps: 1

loss_fn:
reference_policy_kl_penalty: 0.01
ratio_clip_min: 0.2
@@ -12,6 +12,10 @@ grpo:
val_batch_size: 256
seed: 42
overlong_filtering: false
async_grpo:
enabled: false
max_trajectory_age_steps: 1

loss_fn:
reference_policy_kl_penalty: 0.01
ratio_clip_min: 0.2
@@ -12,6 +12,10 @@ grpo:
val_batch_size: 256
seed: 42
overlong_filtering: false
async_grpo:
enabled: false
max_trajectory_age_steps: 1

loss_fn:
reference_policy_kl_penalty: 0.01
ratio_clip_min: 0.2
@@ -12,6 +12,10 @@ grpo:
max_val_samples: 256
val_batch_size: 256
seed: 42
async_grpo:
enabled: false
max_trajectory_age_steps: 1

loss_fn:
reference_policy_kl_penalty: 0.01
ratio_clip_min: 0.2
@@ -21,6 +21,9 @@ grpo:
max_val_samples: 256
val_batch_size: 256
seed: 42
async_grpo:
enabled: false
max_trajectory_age_steps: 1

loss_fn:
reference_policy_kl_penalty: 0.01
@@ -12,6 +12,9 @@ grpo:
val_at_start: false
max_val_samples: 256
val_batch_size: 256
async_grpo:
enabled: false
max_trajectory_age_steps: 1

loss_fn:
reference_policy_kl_penalty: 0.04
@@ -12,6 +12,10 @@ grpo:
val_batch_size: 256
seed: 42
overlong_filtering: false
async_grpo:
enabled: false
max_trajectory_age_steps: 1

loss_fn:
reference_policy_kl_penalty: 0.01
ratio_clip_min: 0.2
@@ -12,6 +12,10 @@ grpo:
val_batch_size: 256
seed: 42
overlong_filtering: false
async_grpo:
enabled: false
max_trajectory_age_steps: 1

loss_fn:
reference_policy_kl_penalty: 0.01
ratio_clip_min: 0.2
@@ -12,6 +12,10 @@ grpo:
val_batch_size: 256
seed: 42
overlong_filtering: false
async_grpo:
enabled: false
max_trajectory_age_steps: 1

loss_fn:
reference_policy_kl_penalty: 0.01
ratio_clip_min: 0.2
@@ -12,6 +12,10 @@ grpo:
max_val_samples: 256
val_batch_size: 256
seed: 42
async_grpo:
enabled: false
max_trajectory_age_steps: 1

loss_fn:
reference_policy_kl_penalty: 0.01
ratio_clip_min: 0.2
@@ -12,6 +12,10 @@ grpo:
val_batch_size: 256
seed: 42
overlong_filtering: false
async_grpo:
enabled: false
max_trajectory_age_steps: 1

loss_fn:
reference_policy_kl_penalty: 0.01
ratio_clip_min: 0.2
@@ -13,6 +13,9 @@ grpo:
max_val_samples: 256
val_batch_size: 256
seed: 42
async_grpo:
enabled: false
max_trajectory_age_steps: 1

loss_fn:
reference_policy_kl_penalty: 0.01
@@ -13,6 +13,9 @@ grpo:
max_val_samples: 256
val_batch_size: 256
seed: 42
async_grpo:
enabled: false
max_trajectory_age_steps: 1

loss_fn:
reference_policy_kl_penalty: 0.01