diff --git a/docs/guides/async-grpo.md b/docs/guides/async-grpo.md new file mode 100644 index 0000000000..a5cd7d3ecb --- /dev/null +++ b/docs/guides/async-grpo.md @@ -0,0 +1,157 @@ +# Train with Async GRPO + +Async GRPO is an asynchronous training mode that allows trajectory generation and policy training to run concurrently, improving GPU utilization and throughput compared to synchronous GRPO. + +## Configure Async GRPO + +This section covers how to configure async GRPO by modifying your settings and includes a complete example configuration. +### Enable Async GRPO + +To use async GRPO, make these configuration changes: + +1. **Enable vLLM async engine**: +```yaml +policy: + generation: + backend: "vllm" + vllm_cfg: + async_engine: true +``` + +2. **Enable importance sampling correction** (required for convergence): +```yaml +loss_fn: + use_importance_sampling_correction: true +``` + +3. **Disable colocated inference** (required for async mode): +```yaml +policy: + generation: + colocated: + enabled: false + resources: + num_nodes: 1 # or more + gpus_per_node: 2 # adjust based on your setup +``` + +4. **Add async GRPO configuration**: +```yaml +grpo: + async_grpo: + max_trajectory_age_steps: 1 # Maximum age, in training steps, for trajectories +``` + +### Complete Example Config +```yaml +policy: + generation: + backend: "vllm" + colocated: + enabled: false + resources: + num_nodes: 1 + gpus_per_node: 2 + vllm_cfg: + async_engine: true + +loss_fn: + use_importance_sampling_correction: true + +grpo: + num_prompts_per_step: 32 + num_generations_per_prompt: 4 + async_grpo: + max_trajectory_age_steps: 1 + +cluster: + num_nodes: 2 + gpus_per_node: 4 +``` + +## Implementation Structure +This section covers the internal architecture of async GRPO and includes detailed explanations of how the core components interact. +### Core Components + +The async GRPO implementation consists of three main components: + +#### 1. 
Main Training Loop (`async_grpo_train` in `grpo.py`) +- Coordinates overall training process +- Samples trajectories from replay buffer +- Runs policy training steps +- Handles validation and checkpointing +- Manages weight synchronization between training and generation + +#### 2. Async Trajectory Collector (`AsyncTrajectoryCollector` in `async_utils.py`) +- Runs in background Ray actor +- Continuously generates trajectories using current policy weights +- Manages generation scheduling and weight version tracking +- Handles pause/resume for weight updates and validation +- Coordinates with replay buffer for trajectory storage + +#### 3. Replay Buffer (`ReplayBuffer` in `async_utils.py`) +- Stores generated trajectories with metadata +- Tracks weight versions for both generation and intended training use +- Implements age-based filtering to prevent stale trajectories +- Provides sampling interface for training steps + +### Weight Version Tracking + +Async GRPO uses a weight versioning system: +- **Generation Weight Version**: The policy weights used to generate a trajectory +- **Target Weight Version**: The training step where the trajectory will be used +- **Max Trajectory Age**: How many steps old a trajectory can be before being discarded + +Example with `max_trajectory_age_steps: 1`: +- Trajectory generated with weights v10 can be used for training steps v10 or v11 +- At training step v12, trajectories from v10 are too old and discarded + +### Coordination Flow + +1. **Startup**: Trajectory collector starts generating trajectories in background +2. **Buffer Fill**: Training waits until buffer has sufficient trajectories +3. **Training Step**: + - Sample trajectories from buffer + - Run policy training + - Update weights and notify collector +4. **Weight Sync**: Collector pauses, waits for weight refit, then resumes +5. 
**Repeat**: Process continues with updated weights + + +### Architecture Diagram + +The following sequence diagram illustrates the interactions between the three main components: + +``` +sequenceDiagram + participant Training as Training Loop + participant Collector as Trajectory Collector + participant Buffer as Replay Buffer + + Note over Training, Buffer: Startup + Training->>Collector: Start generation + Training->>Buffer: Initialize + + Note over Training, Buffer: Main Loop + loop Async Training + par Background Generation + Collector->>Buffer: Store trajectories + and Training Steps + Training->>Buffer: Sample trajectories + Buffer-->>Training: Return valid data + Training->>Training: Update policy weights + Training->>Collector: Sync new weights + end + end +``` + +## Usage Tips + +1. **Buffer Sizing**: The replay buffer size is automatically calculated as: + ``` + buffer_size = num_prompts_per_step Γ— max_trajectory_age_steps Γ— 2 + ``` + +2. **Age Limits**: Start with `max_trajectory_age_steps: 1` and increase if needed for higher throughput + +3. 
**Resource Allocation**: Ensure sufficient GPU memory for both the training and generation clusters diff --git a/docs/index.md b/docs/index.md index 763ce11fea..9c11b8febf 100644 --- a/docs/index.md +++ b/docs/index.md @@ -33,6 +33,7 @@ guides/environments.md guides/eval.md guides/deepseek.md model-quirks.md +guides/async-grpo.md ``` ```{toctree} diff --git a/examples/configs/grpo_math_1B.yaml b/examples/configs/grpo_math_1B.yaml index 2238140b39..bc852de429 100644 --- a/examples/configs/grpo_math_1B.yaml +++ b/examples/configs/grpo_math_1B.yaml @@ -13,6 +13,10 @@ grpo: max_val_samples: 256 val_batch_size: 256 seed: 42 + async_grpo: + enabled: false # Set to true to enable async training mode + # Max age (in training steps) for trajectories used in training + max_trajectory_age_steps: 1 loss_fn: reference_policy_kl_penalty: 0.01 @@ -21,6 +25,8 @@ loss_fn: ratio_clip_c: null # (default off) loss formulation improvements (docs/guides/grpo.md#loss) use_on_policy_kl_approximation: false + # Async GRPO requires importance sampling correction enabled + # Set to true when async_grpo.enabled is true use_importance_sampling_correction: false sequence_level_importance_ratios: false token_level_loss: true diff --git a/examples/configs/grpo_math_1B_megatron.yaml b/examples/configs/grpo_math_1B_megatron.yaml index 511e38c5b5..04f5f5bc96 100644 --- a/examples/configs/grpo_math_1B_megatron.yaml +++ b/examples/configs/grpo_math_1B_megatron.yaml @@ -12,6 +12,9 @@ grpo: val_at_start: false max_val_samples: 256 val_batch_size: 256 + async_grpo: + enabled: false + max_trajectory_age_steps: 1 loss_fn: reference_policy_kl_penalty: 0.01 diff --git a/examples/configs/grpo_math_8B.yaml b/examples/configs/grpo_math_8B.yaml index ce46263375..51331ec509 100644 --- a/examples/configs/grpo_math_8B.yaml +++ b/examples/configs/grpo_math_8B.yaml @@ -4,6 +4,9 @@ defaults: "grpo_math_1B.yaml" grpo: num_prompts_per_step: 64 num_generations_per_prompt: 32 + async_grpo: + enabled: false + 
max_trajectory_age_steps: 1 policy: model_name: "meta-llama/Llama-3.1-8B-Instruct" diff --git a/examples/configs/recipes/llm/grpo-deepscaler-1.5b-8K.yaml b/examples/configs/recipes/llm/grpo-deepscaler-1.5b-8K.yaml index 491f787f7b..f6cc626890 100644 --- a/examples/configs/recipes/llm/grpo-deepscaler-1.5b-8K.yaml +++ b/examples/configs/recipes/llm/grpo-deepscaler-1.5b-8K.yaml @@ -13,6 +13,9 @@ grpo: val_batch_size: 32 seed: 42 overlong_filtering: false + async_grpo: + enabled: false + max_trajectory_age_steps: 1 loss_fn: reference_policy_kl_penalty: 0.0 diff --git a/examples/configs/recipes/llm/grpo-gemma3-1b-it-1n8g-fsdp2tp1.yaml b/examples/configs/recipes/llm/grpo-gemma3-1b-it-1n8g-fsdp2tp1.yaml index c240fa984c..091cb2909a 100644 --- a/examples/configs/recipes/llm/grpo-gemma3-1b-it-1n8g-fsdp2tp1.yaml +++ b/examples/configs/recipes/llm/grpo-gemma3-1b-it-1n8g-fsdp2tp1.yaml @@ -12,6 +12,10 @@ grpo: val_batch_size: 256 seed: 42 overlong_filtering: false + async_grpo: + enabled: false + max_trajectory_age_steps: 1 + loss_fn: reference_policy_kl_penalty: 0.01 ratio_clip_min: 0.2 diff --git a/examples/configs/recipes/llm/grpo-gemma3-27b-it-8n8g-fsdp2tp8-actckpt-long.yaml b/examples/configs/recipes/llm/grpo-gemma3-27b-it-8n8g-fsdp2tp8-actckpt-long.yaml index ac8069c79b..4c3351970c 100644 --- a/examples/configs/recipes/llm/grpo-gemma3-27b-it-8n8g-fsdp2tp8-actckpt-long.yaml +++ b/examples/configs/recipes/llm/grpo-gemma3-27b-it-8n8g-fsdp2tp8-actckpt-long.yaml @@ -12,6 +12,10 @@ grpo: val_batch_size: 256 seed: 42 overlong_filtering: false + async_grpo: + enabled: false + max_trajectory_age_steps: 1 + loss_fn: reference_policy_kl_penalty: 0.01 ratio_clip_min: 0.2 diff --git a/examples/configs/recipes/llm/grpo-gspo-deepscaler-1.5b-8K.yaml b/examples/configs/recipes/llm/grpo-gspo-deepscaler-1.5b-8K.yaml index 2d0d8b2602..e1b7c4d809 100644 --- a/examples/configs/recipes/llm/grpo-gspo-deepscaler-1.5b-8K.yaml +++ b/examples/configs/recipes/llm/grpo-gspo-deepscaler-1.5b-8K.yaml @@ 
-13,6 +13,9 @@ grpo: max_val_samples: 480 val_batch_size: 32 seed: 42 + async_grpo: + enabled: false + max_trajectory_age_steps: 1 loss_fn: reference_policy_kl_penalty: 0.0 diff --git a/examples/configs/recipes/llm/grpo-llama3.1-8b-instruct-1n8g-megatron-fp8.yaml b/examples/configs/recipes/llm/grpo-llama3.1-8b-instruct-1n8g-megatron-fp8.yaml index 454583d04a..81ca15f6bd 100644 --- a/examples/configs/recipes/llm/grpo-llama3.1-8b-instruct-1n8g-megatron-fp8.yaml +++ b/examples/configs/recipes/llm/grpo-llama3.1-8b-instruct-1n8g-megatron-fp8.yaml @@ -12,6 +12,10 @@ grpo: max_val_samples: 256 val_batch_size: 256 seed: 42 + async_grpo: + enabled: false + max_trajectory_age_steps: 1 + loss_fn: reference_policy_kl_penalty: 0.01 ratio_clip_min: 0.2 diff --git a/examples/configs/recipes/llm/grpo-llama3.1-8b-instruct-2n8g-fsdp2tp1-noncolocated.yaml b/examples/configs/recipes/llm/grpo-llama3.1-8b-instruct-2n8g-fsdp2tp1-noncolocated.yaml index 2ad32d599f..17b474bd72 100644 --- a/examples/configs/recipes/llm/grpo-llama3.1-8b-instruct-2n8g-fsdp2tp1-noncolocated.yaml +++ b/examples/configs/recipes/llm/grpo-llama3.1-8b-instruct-2n8g-fsdp2tp1-noncolocated.yaml @@ -12,6 +12,10 @@ grpo: max_val_samples: 256 val_batch_size: 256 seed: 42 + async_grpo: + enabled: false + max_trajectory_age_steps: 1 + loss_fn: reference_policy_kl_penalty: 0.01 ratio_clip_min: 0.2 diff --git a/examples/configs/recipes/llm/grpo-llama3.1-8b-instruct-4n8g-fsdp2tp1-long.v3.yaml b/examples/configs/recipes/llm/grpo-llama3.1-8b-instruct-4n8g-fsdp2tp1-long.v3.yaml index bc4c1cc51c..1c2b3840ca 100644 --- a/examples/configs/recipes/llm/grpo-llama3.1-8b-instruct-4n8g-fsdp2tp1-long.v3.yaml +++ b/examples/configs/recipes/llm/grpo-llama3.1-8b-instruct-4n8g-fsdp2tp1-long.v3.yaml @@ -12,6 +12,10 @@ grpo: val_batch_size: 256 seed: 42 overlong_filtering: false + async_grpo: + enabled: false + max_trajectory_age_steps: 1 + loss_fn: reference_policy_kl_penalty: 0.01 ratio_clip_min: 0.2 diff --git 
a/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v3.yaml b/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v3.yaml index 8f8c2686eb..eddf09bf97 100644 --- a/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v3.yaml +++ b/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v3.yaml @@ -12,6 +12,10 @@ grpo: val_batch_size: 256 seed: 42 overlong_filtering: false + async_grpo: + enabled: false + max_trajectory_age_steps: 1 + loss_fn: reference_policy_kl_penalty: 0.01 ratio_clip_min: 0.2 diff --git a/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-megatron.yaml b/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-megatron.yaml index deca81ed9c..4ad29901fa 100755 --- a/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-megatron.yaml +++ b/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-megatron.yaml @@ -12,6 +12,10 @@ grpo: max_val_samples: 256 val_batch_size: 256 seed: 42 + async_grpo: + enabled: false + max_trajectory_age_steps: 1 + loss_fn: reference_policy_kl_penalty: 0.01 ratio_clip_min: 0.2 diff --git a/examples/configs/recipes/llm/grpo-math-qwen3-30ba3b-megatron-tp4-32k.yaml b/examples/configs/recipes/llm/grpo-math-qwen3-30ba3b-megatron-tp4-32k.yaml index a703493fed..507b1eefd8 100644 --- a/examples/configs/recipes/llm/grpo-math-qwen3-30ba3b-megatron-tp4-32k.yaml +++ b/examples/configs/recipes/llm/grpo-math-qwen3-30ba3b-megatron-tp4-32k.yaml @@ -21,6 +21,9 @@ grpo: max_val_samples: 256 val_batch_size: 256 seed: 42 + async_grpo: + enabled: false + max_trajectory_age_steps: 1 loss_fn: reference_policy_kl_penalty: 0.01 diff --git a/examples/configs/recipes/llm/grpo-moonlight-16ba3b-4n8g-megatron.yaml b/examples/configs/recipes/llm/grpo-moonlight-16ba3b-4n8g-megatron.yaml index 8944064ba4..a0784ba746 100644 --- a/examples/configs/recipes/llm/grpo-moonlight-16ba3b-4n8g-megatron.yaml +++ 
b/examples/configs/recipes/llm/grpo-moonlight-16ba3b-4n8g-megatron.yaml @@ -12,6 +12,9 @@ grpo: val_at_start: false max_val_samples: 256 val_batch_size: 256 + async_grpo: + enabled: false + max_trajectory_age_steps: 1 loss_fn: reference_policy_kl_penalty: 0.04 diff --git a/examples/configs/recipes/llm/grpo-qwen2.5-32b-32n8g-fsdp2tp8sp-actckpt-long.v3.yaml b/examples/configs/recipes/llm/grpo-qwen2.5-32b-32n8g-fsdp2tp8sp-actckpt-long.v3.yaml index 59b5a453af..7fd4007279 100644 --- a/examples/configs/recipes/llm/grpo-qwen2.5-32b-32n8g-fsdp2tp8sp-actckpt-long.v3.yaml +++ b/examples/configs/recipes/llm/grpo-qwen2.5-32b-32n8g-fsdp2tp8sp-actckpt-long.v3.yaml @@ -12,6 +12,10 @@ grpo: val_batch_size: 256 seed: 42 overlong_filtering: false + async_grpo: + enabled: false + max_trajectory_age_steps: 1 + loss_fn: reference_policy_kl_penalty: 0.01 ratio_clip_min: 0.2 diff --git a/examples/configs/recipes/llm/grpo-qwen2.5-32b-32n8g-fsdp2tp8sp-actckpt.v3.yaml b/examples/configs/recipes/llm/grpo-qwen2.5-32b-32n8g-fsdp2tp8sp-actckpt.v3.yaml index 1d51d46bb3..f163092404 100644 --- a/examples/configs/recipes/llm/grpo-qwen2.5-32b-32n8g-fsdp2tp8sp-actckpt.v3.yaml +++ b/examples/configs/recipes/llm/grpo-qwen2.5-32b-32n8g-fsdp2tp8sp-actckpt.v3.yaml @@ -12,6 +12,10 @@ grpo: val_batch_size: 256 seed: 42 overlong_filtering: false + async_grpo: + enabled: false + max_trajectory_age_steps: 1 + loss_fn: reference_policy_kl_penalty: 0.01 ratio_clip_min: 0.2 diff --git a/examples/configs/recipes/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp2tp4sp.v3.yaml b/examples/configs/recipes/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp2tp4sp.v3.yaml index 20ef24f75d..f6ecc1e390 100644 --- a/examples/configs/recipes/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp2tp4sp.v3.yaml +++ b/examples/configs/recipes/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp2tp4sp.v3.yaml @@ -12,6 +12,10 @@ grpo: val_batch_size: 256 seed: 42 overlong_filtering: false + async_grpo: + enabled: false + max_trajectory_age_steps: 1 + loss_fn: 
reference_policy_kl_penalty: 0.01 ratio_clip_min: 0.2 diff --git a/examples/configs/recipes/llm/grpo-qwen2.5-7b-instruct-4n8g-megatron.yaml b/examples/configs/recipes/llm/grpo-qwen2.5-7b-instruct-4n8g-megatron.yaml index 3d76865969..1209040cda 100755 --- a/examples/configs/recipes/llm/grpo-qwen2.5-7b-instruct-4n8g-megatron.yaml +++ b/examples/configs/recipes/llm/grpo-qwen2.5-7b-instruct-4n8g-megatron.yaml @@ -12,6 +12,10 @@ grpo: max_val_samples: 256 val_batch_size: 256 seed: 42 + async_grpo: + enabled: false + max_trajectory_age_steps: 1 + loss_fn: reference_policy_kl_penalty: 0.01 ratio_clip_min: 0.2 diff --git a/examples/configs/recipes/llm/grpo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.v3.yaml b/examples/configs/recipes/llm/grpo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.v3.yaml index cc14c07a8b..b8f79eb6ae 100644 --- a/examples/configs/recipes/llm/grpo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.v3.yaml +++ b/examples/configs/recipes/llm/grpo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.v3.yaml @@ -12,6 +12,10 @@ grpo: val_batch_size: 256 seed: 42 overlong_filtering: false + async_grpo: + enabled: false + max_trajectory_age_steps: 1 + loss_fn: reference_policy_kl_penalty: 0.01 ratio_clip_min: 0.2 diff --git a/examples/configs/recipes/vlm/vlm_grpo-qwen2.5-vl-3b-instruct-clevr-1n2g-dtensor2tp1.v1.yaml b/examples/configs/recipes/vlm/vlm_grpo-qwen2.5-vl-3b-instruct-clevr-1n2g-dtensor2tp1.v1.yaml index 0c6786cc2d..3f744e1a30 100644 --- a/examples/configs/recipes/vlm/vlm_grpo-qwen2.5-vl-3b-instruct-clevr-1n2g-dtensor2tp1.v1.yaml +++ b/examples/configs/recipes/vlm/vlm_grpo-qwen2.5-vl-3b-instruct-clevr-1n2g-dtensor2tp1.v1.yaml @@ -13,6 +13,9 @@ grpo: max_val_samples: 256 val_batch_size: 256 seed: 42 + async_grpo: + enabled: false + max_trajectory_age_steps: 1 loss_fn: reference_policy_kl_penalty: 0.01 diff --git a/examples/configs/recipes/vlm/vlm_grpo-smolvlm2-2.2b-instruct-clevr-1n2g-dtensor2tp1.v1.yaml 
b/examples/configs/recipes/vlm/vlm_grpo-smolvlm2-2.2b-instruct-clevr-1n2g-dtensor2tp1.v1.yaml index d3dcc466d7..66feabee46 100644 --- a/examples/configs/recipes/vlm/vlm_grpo-smolvlm2-2.2b-instruct-clevr-1n2g-dtensor2tp1.v1.yaml +++ b/examples/configs/recipes/vlm/vlm_grpo-smolvlm2-2.2b-instruct-clevr-1n2g-dtensor2tp1.v1.yaml @@ -13,6 +13,9 @@ grpo: max_val_samples: 256 val_batch_size: 256 seed: 42 + async_grpo: + enabled: false + max_trajectory_age_steps: 1 loss_fn: reference_policy_kl_penalty: 0.01 diff --git a/examples/configs/vlm_grpo_3B.yaml b/examples/configs/vlm_grpo_3B.yaml index c71eb78e12..1e4da58985 100644 --- a/examples/configs/vlm_grpo_3B.yaml +++ b/examples/configs/vlm_grpo_3B.yaml @@ -14,6 +14,9 @@ grpo: max_val_samples: 256 val_batch_size: 256 seed: 42 + async_grpo: + enabled: false + max_trajectory_age_steps: 1 loss_fn: reference_policy_kl_penalty: 0.01 diff --git a/examples/run_grpo_math.py b/examples/run_grpo_math.py index 5f446b08a1..51adfddced 100644 --- a/examples/run_grpo_math.py +++ b/examples/run_grpo_math.py @@ -188,20 +188,47 @@ def main() -> None: master_config, ) = setup(config, tokenizer, dataset, val_dataset) - grpo_train( - policy, - policy_generation, - dataloader, - val_dataloader, - tokenizer, - loss_fn, - task_to_env, - val_task_to_env, - logger, - checkpointer, - grpo_state, - master_config, - ) + # Check if async mode is enabled + if "async_grpo" in config["grpo"] and config["grpo"]["async_grpo"]["enabled"]: + from nemo_rl.algorithms.grpo import async_grpo_train + + print("πŸš€ Running async GRPO training") + + async_config = config["grpo"]["async_grpo"] + # Run async GRPO training + async_grpo_train( + policy=policy, + policy_generation=policy_generation, + dataloader=dataloader, + val_dataloader=val_dataloader, + tokenizer=tokenizer, + loss_fn=loss_fn, + task_to_env=task_to_env, + val_task_to_env=val_task_to_env, + logger=logger, + checkpointer=checkpointer, + grpo_save_state=grpo_state, + master_config=master_config, + 
max_trajectory_age_steps=async_config["max_trajectory_age_steps"], + ) + else: + print("πŸš€ Running synchronous GRPO training") + + # Run standard GRPO training + grpo_train( + policy, + policy_generation, + dataloader, + val_dataloader, + tokenizer, + loss_fn, + task_to_env, + val_task_to_env, + logger, + checkpointer, + grpo_state, + master_config, + ) if __name__ == "__main__": diff --git a/nemo_rl/algorithms/async_utils.py b/nemo_rl/algorithms/async_utils.py new file mode 100644 index 0000000000..55703a27b1 --- /dev/null +++ b/nemo_rl/algorithms/async_utils.py @@ -0,0 +1,674 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import threading as _threading +import time +from typing import Any, Optional + +import ray +from torchdata.stateful_dataloader import StatefulDataLoader +from transformers import PreTrainedTokenizerBase + +from nemo_rl.algorithms.grpo import MasterConfig +from nemo_rl.data.interfaces import DatumSpec +from nemo_rl.distributed.batched_data_dict import BatchedDataDict +from nemo_rl.environments.interfaces import EnvironmentInterface +from nemo_rl.experience.rollouts import ( + run_async_multi_turn_rollout, +) +from nemo_rl.models.generation.interfaces import GenerationInterface + +TokenizerType = PreTrainedTokenizerBase + + +@ray.remote # pragma: no cover +class ReplayBuffer: + """Replay buffer storing per-prompt groups. 
+ + A single entry corresponds to 1 prompt repeated by + grpo.num_generations_per_prompt (required to compute per-prompt advantages). + """ + + def __init__(self, max_size: int): + if max_size <= 0: + raise ValueError(f"max_size must be positive, got {max_size}") + self.max_size = max_size + self.trajectories = [] + # If trajectory_version is 1 and target_weight_version is 4 it means that weight version 1 was used for generating a trajectory and this trajectory will be used for training when weight version is 4. + self.trajectory_versions = [] # it is the weight-version used for generation of a trajectory + self.target_weight_versions = [] # it is the weight-version of the trainer where this trajectory will be used. + + self.last_target_weight_already_generated = -1 + self._lock = _threading.Lock() + + def push_with_wait_signal( + self, + trajectory: dict[str, Any], + weight_version: int, + target_weight_version: int, + ) -> str: + """Add a per-prompt trajectory group with metadata. + + Args: + trajectory: data dict + weight_version: version of the model weights used for generation + target_weight_version: version of the model weights this trajectory is intended for training + """ + with self._lock: + if len(self.trajectories) >= self.max_size: + return "full" + + print("πŸ” ReplayBuffer.push_with_wait_signal: Adding trajectory") + self.trajectories.append(trajectory) + self.trajectory_versions.append(weight_version) + self.target_weight_versions.append(target_weight_version) + self.last_target_weight_already_generated = max( + self.last_target_weight_already_generated, target_weight_version + ) + print( + f"ReplayBuffer state: {len(self.trajectories)} groups, versions={self.trajectory_versions}, targets={self.target_weight_versions}, last_target_weight_already_generated={self.last_target_weight_already_generated}" + ) + return "success" + + def get_debug_info(self) -> dict: + """Get debug information about buffer state.""" + return { + "total_trajectories": 
len(self.trajectories), + "trajectory_versions": self.trajectory_versions, + "target_weight_versions": self.target_weight_versions, + "max_size": self.max_size, + } + + def get_last_target_weight_already_generated(self) -> int: + with self._lock: + return self.last_target_weight_already_generated + + def get_existing_target_weights(self) -> set[int]: + """Get set of target weight versions that already have trajectories.""" + with self._lock: + return set(self.target_weight_versions) + + def sample( + self, + num_prompt_groups: int, + current_weight_version: int, + max_age_steps: int, + ) -> Optional[dict[str, Any]]: + """Sample per-prompt trajectory groups intended for the current training step. + + Only returns trajectories with target_weight_version == current_weight_version. + If insufficient trajectories are available, returns None to stall training + until the remaining trajectories are generated. This ensures no trajectory + loses its last chance to be used for its intended training step. 
+ + Returns: + Dictionary with 'trajectories' and 'avg_trajectory_age' keys, or None if insufficient data + """ + with self._lock: + if not self.trajectories: + return None + + total_trajectories = len(self.trajectories) + print("πŸ” ReplayBuffer sampling debug:") + print(f" {current_weight_version=}, {max_age_steps=}") + print(f" {self.trajectory_versions=}") + + # For debugging: check for unexpected old trajectories + from collections import Counter + + version_counts = Counter(self.trajectory_versions) + print(f" {version_counts=}") + + # Compute minimum valid version based on age window + # max_age_steps=1 means trajectories from the last 1 step are valid + min_valid_version = max(0, current_weight_version - max_age_steps) + print(f" {min_valid_version=}") + + # Check for unexpected old trajectories + old_trajectories = [ + v for v in self.trajectory_versions if v < min_valid_version + ] + if old_trajectories: + raise ValueError( + f"Found {len(old_trajectories)} trajectories older than min_valid_version {min_valid_version}" + ) + + # Filter for valid trajectories without modifying the buffer + valid_indices = [ + i + for i, v in enumerate(self.trajectory_versions) + if min_valid_version <= v <= current_weight_version + ] + print( + f" valid_indices: {len(valid_indices)}/{total_trajectories} trajectories within age window" + ) + if not valid_indices: + print("No trajectories available for sampling.") + return None + + # Enforce exact number of groups if available; otherwise, signal to wait + if len(valid_indices) < num_prompt_groups: + print( + f"Insufficient valid groups: have {len(valid_indices)}, need {num_prompt_groups}. Waiting for buffer to fill." 
+ ) + return None + + # Only select trajectories intended for the current training step + # This ensures no trajectory loses its "last chance" to be used for its intended step + intended_indices = [ + i + for i in valid_indices + if self.target_weight_versions[i] == current_weight_version + ] + + print( + f" 🎯 Found {len(intended_indices)} trajectories intended for current step {current_weight_version}" + ) + + # Stall training if we don't have enough trajectories intended for this step + if len(intended_indices) < num_prompt_groups: + print( + f" ⏸️ STALLING: Need {num_prompt_groups} trajectories for step {current_weight_version}, but only {len(intended_indices)} are ready" + ) + print( + f" ⏸️ Training will wait for remaining {num_prompt_groups - len(intended_indices)} trajectories to be generated" + ) + return None + + # Select exactly the trajectories intended for this step (FIFO within same target) + selected: list[int] = intended_indices[:num_prompt_groups] + print( + f" βœ… Selected {len(selected)} trajectories all intended for step {current_weight_version}" + ) + + from collections import Counter + + sampled_weights = [self.trajectory_versions[i] for i in selected] + avg_trajectory_age = current_weight_version - sum(sampled_weights) / len( + sampled_weights + ) + print( + f"βœ… Selected counts by generation weight-version: {Counter(sampled_weights)}" + ) + print(f"πŸ“Š Average trajectory age: {avg_trajectory_age:.2f} steps") + print( + f"🎯 All selected trajectories target step {current_weight_version} (100% target match)" + ) + + sampled_items = [self.trajectories[i] for i in selected] + + # Remove selected items in reverse order to maintain correct indices + for idx in sorted(selected, reverse=True): + self.trajectory_versions.pop(idx) + self.target_weight_versions.pop(idx) + self.trajectories.pop(idx) + print( + f"πŸ—‘οΈ Consumed and removed {len(selected)} groups from buffer, old buffer size: {total_trajectories}, new buffer size: 
{len(self.trajectories)}, new target weight versions {self.target_weight_versions}" + ) + + return { + "trajectories": sampled_items, + "avg_trajectory_age": avg_trajectory_age, + } + + def size(self) -> int: + """Return current buffer size.""" + with self._lock: + return len(self.trajectories) + + def clear(self) -> None: + """Clear the buffer.""" + with self._lock: + self.trajectories.clear() + self.trajectory_versions.clear() + self.target_weight_versions.clear() + + +@ray.remote # pragma: no cover +class AsyncTrajectoryCollector: + """Collects trajectories asynchronously and adds them to replay buffer.""" + + def __init__( + self, + policy_generation: GenerationInterface, + tokenizer: TokenizerType, + task_to_env: dict[str, EnvironmentInterface], + master_config: MasterConfig, + replay_buffer: Any, + start_step: int = 0, + ): + self.policy_generation = policy_generation + self.tokenizer = tokenizer + self.task_to_env = task_to_env + self.master_config = master_config + self.replay_buffer = replay_buffer + self.running = False + + self._pg_lock: _threading.Lock = _threading.Lock() + + # Event for manual pause/resume control + self._manual_pause_cleared = _threading.Event() + self._manual_pause_cleared.set() + + self._refit_pause_cleared = _threading.Event() + self._refit_pause_cleared.set() # Start in cleared state + + self.current_weight_version: int = start_step + self.initial_weight_version: int = start_step + + # Track when generation limits cause collection to pause + self._last_limit_warning_version = None + + # Event to signal when generation limits are cleared (more efficient than polling) + self._generation_limit_cleared = _threading.Event() + self._generation_limit_cleared.set() # Start in cleared state + + # Track threads + self._inflight_threads: set[_threading.Thread] = set() + self._threads_lock: _threading.Lock = _threading.Lock() + + # Limit in-flight generator requests to num_prompts_per_step + max_inflight = 
int(self.master_config["grpo"]["num_prompts_per_step"]) or 1 + self._inflight_sema = _threading.Semaphore(max_inflight) + + # Simple lock to prevent race conditions when checking/spawning workers + self._generation_check_lock: _threading.Lock = _threading.Lock() + # Track which target weights are currently being generated (globally) + self._generating_targets: set[int] = set() + + def _calculate_target_weights(self, generation_weight_version: int) -> list[int]: + """Calculate target weight versions for given generation weight version. + + The list of versions returned enumerate the possible version a generation + server can target. These versions are looped over to see what training + step they can target. If all target versions are exhausted, this generation + server will remain idle until the next weight update. + + Example: + generation_weight_version = 10 + max_trajectory_age_steps = 4 + + Returns: + [11, 12, 13, 14] # Meaning this generation server can create trajectories for training step 11, 12, 13, 14 + """ + # Read async config strictly from grpo.async_grpo + async_cfg = self.master_config.get("grpo", {}).get("async_grpo", {}) + max_trajectory_age = async_cfg["max_trajectory_age_steps"] + if generation_weight_version == self.initial_weight_version: + return [ + i + for i in range( + self.initial_weight_version, + self.initial_weight_version + max_trajectory_age + 1, + ) + ] + + return [generation_weight_version + i for i in range(1, max_trajectory_age + 1)] + + def _get_next_target_for_generation( + self, generation_weight_version: int + ) -> Optional[int]: + """Get the next target weight that needs generation (if any).""" + target_weights = self._calculate_target_weights(generation_weight_version) + last_target_weight_already_generated = ray.get( + self.replay_buffer.get_last_target_weight_already_generated.remote() + ) + + with self._generation_check_lock: + for target_weight in target_weights: + if ( + target_weight > 
last_target_weight_already_generated + and target_weight not in self._generating_targets + ): + self._generating_targets.add(target_weight) + print(f"🎯 Reserved target weight {target_weight} for generation") + return target_weight + + return None + + def set_weight_version(self, version: int) -> None: + self.current_weight_version = version + + # Resume collection if it was paused due to generation limits + was_paused = not self._generation_limit_cleared.is_set() + if was_paused: + self._generation_limit_cleared.set() # Signal that collection can resume + print(f"πŸ”„ Updated weight version to {version}, resuming collection") + else: + print(f"πŸ”„ Updated weight version to {version}") + + def _should_pause_for_generation_limits(self) -> bool: + """Check if collection should be paused due to generation limits.""" + try: + target_weights = self._calculate_target_weights(self.current_weight_version) + last_target_weight_already_generated = ray.get( + self.replay_buffer.get_last_target_weight_already_generated.remote() + ) + + # Check if any target weight in our range needs generation + with self._generation_check_lock: + for target_weight in target_weights: + if ( + target_weight > last_target_weight_already_generated + and target_weight not in self._generating_targets + ): + return False # Found a target that needs generation + + print( + f"⏸️ All target weights {target_weights} already generated or in progress, pausing" + ) + return True + except Exception: + return False + + def start_collection(self, dataloader: StatefulDataLoader) -> None: + """Start collecting trajectories from dataloader.""" + self.running = True + self.dataloader = dataloader + + print("Started continuous trajectory collection") + + self.collection_thread = _threading.Thread(target=self._collection_loop) + self.collection_thread.daemon = True + self.collection_thread.start() + + print("Collection thread started, start_collection returning") + + def _collection_loop(self): + """Run the 
collection loop in background thread.""" + try: + for batch in self.dataloader: + if not self.running: + break + + # Check if manually paused and wait + if not self._manual_pause_cleared.is_set() and self.running: + self._manual_pause_cleared.wait() + + # Check if refit is in progress and wait + if not self._refit_pause_cleared.is_set() and self.running: + print("⏸️ Pausing collection for refit...") + self._refit_pause_cleared.wait() + print("▢️ Refit completed, resuming collection") + + # Check if generation limits require pausing collection + if self._should_pause_for_generation_limits() and self.running: + # Only log warning once per weight version + if self._last_limit_warning_version != self.current_weight_version: + async_cfg = self.master_config.get("grpo", {}).get( + "async_grpo", {} + ) + max_trajectory_age = async_cfg["max_trajectory_age_steps"] + target_weights = [ + self.current_weight_version + i + for i in range(max_trajectory_age) + ] + + print( + f"⏸️ Pausing collection: all target weights {target_weights} for weight version {self.current_weight_version} " + f"already exist in buffer. Waiting for weight update..." + ) + self._last_limit_warning_version = self.current_weight_version + + self._generation_limit_cleared.clear() # Clear the event to pause + + # Efficiently wait for generation limits to be cleared (no polling!) 
+ self._generation_limit_cleared.wait() + + # Double-check we're still running after being woken up + if not self.running: + break + + if not self.running: + break + + self._process_batch(batch) + + except Exception as e: + print(f"❌ Error in trajectory collection: {e}") + import traceback + + traceback.print_exc() + finally: + self.running = False + print("πŸ›‘ Trajectory collection stopped") + + def _process_batch(self, batch: BatchedDataDict[DatumSpec]) -> None: + """Process a single batch and generate for one target weight.""" + try: + generation_weight_version = self.current_weight_version + num_generations = self.master_config["grpo"]["num_generations_per_prompt"] + num_prompts = batch.size + + # Get the next target weight that needs generation + target_weight = self._get_next_target_for_generation( + generation_weight_version + ) + + if target_weight is None: + print( + f"πŸ”„ No targets need generation for weight {generation_weight_version}" + ) + return + + print( + f"🎯 Generating for target weight {target_weight} from generation_weight_version {generation_weight_version}" + ) + + # Generate for all prompts in this batch for the target weight + for prompt_idx in range(num_prompts): + # Wait for refit to complete if in progress + if not self._refit_pause_cleared.is_set() and self.running: + with self._threads_lock: + active_threads = len(self._inflight_threads) + print( + f"⏸️ Waiting for refit to complete before starting new generation ({active_threads} threads still active)" + ) + self._refit_pause_cleared.wait() + + # After refit finishes if weight version has updated, reflect that in the new trajectories + generation_weight_version = self.current_weight_version + + single_prompt_batch = batch.slice(prompt_idx, prompt_idx + 1) + repeated_batch = single_prompt_batch.repeat_interleave(num_generations) + + self._inflight_sema.acquire() + worker = _threading.Thread( + target=self._run_prompt_group_worker, + args=( + repeated_batch, + 
generation_weight_version, + target_weight, + prompt_idx, + ), + daemon=True, + ) + with self._threads_lock: + self._inflight_threads.add(worker) + worker.start() + + self._cleanup_finished_threads() + + except Exception as e: + print(f"❌ Error processing batch: {e}") + import traceback + + traceback.print_exc() + + def get_weight_version(self) -> int: + return self.current_weight_version + + def pause(self) -> None: + """Pause trajectory collection.""" + self._manual_pause_cleared.clear() # Signal collection to pause + print("Trajectory collection paused") + + def resume(self) -> None: + """Resume trajectory collection.""" + self._manual_pause_cleared.set() # Signal collection to resume + print("Trajectory collection resumed") + + def prepare_for_refit(self) -> None: + """Pause new generation starts and wait for pending generations to complete before refit.""" + start_time = time.time() + print("πŸ”„ Preparing for refit: pausing new generations...") + + # Pause new generation starts + self._refit_pause_cleared.clear() + print("⏸️ New generation starts paused") + + # Wait for all pending generations to complete + # Note that this is suboptimal for async performance and will be fixed in a follow-up PR where two more options will be added: + # 1. Pause the generations at their current decoding step, update the weights and continue with decoding. + # 2. Stop the current generations, store in a buffer and resume them in next iteration with new weights. 
+ self.wait_for_pending_generations() + + elapsed = time.time() - start_time + print( + f"βœ… All pending generations completed, ready for refit (took {elapsed:.2f}s)" + ) + + def resume_after_refit(self) -> None: + """Resume new generation starts after refit is complete.""" + print("πŸ”„ Resuming generation starts after refit") + self._refit_pause_cleared.set() + + def wait_for_pending_generations(self) -> None: + """Wait for all in-flight generation threads to complete.""" + start_time = time.time() + + while True: + with self._threads_lock: + finished = {t for t in self._inflight_threads if not t.is_alive()} + for t in finished: + self._inflight_threads.remove(t) + + pending_count = len(self._inflight_threads) + + if pending_count == 0: + print("βœ… All generation threads completed") + break + + elapsed = time.time() - start_time + print( + f"⏳ Waiting for {pending_count} pending generation threads... ({elapsed:.1f}s elapsed)" + ) + time.sleep(0.5) + + def get_dataloader_state(self) -> dict: + """Get the current dataloader state for checkpointing.""" + if hasattr(self, "dataloader") and hasattr(self.dataloader, "state_dict"): + return self.dataloader.state_dict() + return {} + + def _cleanup_finished_threads(self) -> None: + with self._threads_lock: + finished = {t for t in self._inflight_threads if not t.is_alive()} + for t in finished: + self._inflight_threads.remove(t) + + def _run_prompt_group_worker( + self, + repeated_batch: BatchedDataDict[DatumSpec], + generation_weight_version: int, + target_weight_version: int, + prompt_idx: int, + ) -> None: + try: + # Run rollout for this prompt group + # Async engine supports concurrent generation; avoid locking + final_batch, rollout_metrics = run_async_multi_turn_rollout( + policy_generation=self.policy_generation, + input_batch=repeated_batch, + tokenizer=self.tokenizer, + task_to_env=self.task_to_env, + max_seq_len=self.master_config["policy"]["max_total_sequence_length"], + 
max_rollout_turns=self.master_config["grpo"]["max_rollout_turns"], + greedy=False, + ) + + # Move to CPU and push to buffer (avoid blocking on GC/push) + final_batch_cpu = final_batch.to("cpu") + del final_batch + + trajectory_group = { + "batch": final_batch_cpu, + "rollout_metrics": rollout_metrics, + "timestamp": time.time(), + } + + # Use exponential backoff when buffer is full + try: + backoff_delay = 0.01 + while self.running: + status = ray.get( + self.replay_buffer.push_with_wait_signal.remote( + trajectory_group, + generation_weight_version, + target_weight_version, + ) + ) + if status == "success": + print( + f"πŸ“¦ Buffered per-prompt group (prompt_idx {prompt_idx}, target_weight {target_weight_version})" + ) + + # Release reservation when FIRST prompt group for this target is successfully buffered + if prompt_idx == 0: + with self._generation_check_lock: + if target_weight_version in self._generating_targets: + self._generating_targets.discard( + target_weight_version + ) + print( + f"🧹 Released reservation for target weight {target_weight_version} (first prompt buffered)" + ) + break + elif status == "full": + # Exponential backoff up to 1 second + time.sleep(min(backoff_delay, 1.0)) + backoff_delay *= 1.5 + else: + # Unexpected status, wait briefly + time.sleep(0.01) + except Exception as e: + print(f"❌ Failed to enqueue per-prompt group to buffer: {e}") + import traceback + + traceback.print_exc() + except Exception as e: + print(f"❌ Error in prompt group worker: {e}") + import traceback + + traceback.print_exc() + finally: + # Clean up reservation in case of error (if not already cleaned up) + with self._generation_check_lock: + if target_weight_version in self._generating_targets: + self._generating_targets.discard(target_weight_version) + print( + f"🧹 Emergency cleanup: Released reservation for target weight {target_weight_version}" + ) + + # Detach thread record when finished + with self._threads_lock: + current = _threading.current_thread() + if 
current in self._inflight_threads: + self._inflight_threads.remove(current) + try: + self._inflight_sema.release() + except Exception: + import traceback + + traceback.print_exc() diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index d64dad01de..190f3c2921 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -11,7 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import gc import os +import time import warnings from contextlib import nullcontext from pathlib import Path @@ -40,6 +42,7 @@ get_keys_from_message_log, ) from nemo_rl.distributed.batched_data_dict import BatchedDataDict +from nemo_rl.distributed.ray_actor_environment_registry import get_actor_python_env from nemo_rl.distributed.virtual_cluster import ClusterConfig, RayVirtualCluster from nemo_rl.environments.interfaces import EnvironmentInterface from nemo_rl.experience.rollouts import ( @@ -59,6 +62,7 @@ ) from nemo_rl.utils.nsys import maybe_gpu_profile_step from nemo_rl.utils.timer import TimeoutChecker, Timer +from nemo_rl.utils.venvs import create_local_venv_on_each_node # =============================================================================== # Configuration @@ -66,6 +70,14 @@ TokenizerType = TypeVar("TokenizerType", bound=PreTrainedTokenizerBase) +class AsyncGRPOConfig(TypedDict): + enabled: bool + # Maximum trajectory age in training steps for samples drawn from the + # async replay buffer. Trajectories older than this are excluded during + # sampling; buffer sizing also scales with this value. 
+ max_trajectory_age_steps: int + + class GRPOConfig(TypedDict): num_prompts_per_step: int num_generations_per_prompt: int @@ -79,6 +91,7 @@ class GRPOConfig(TypedDict): val_at_start: bool max_val_samples: int seed: int + async_grpo: NotRequired[AsyncGRPOConfig] overlong_filtering: NotRequired[bool] @@ -1121,4 +1134,640 @@ def validate( # Make sure to reset the timer after validation timer.reset() + # Explicit GPU memory cleanup after validation + gc.collect() + torch.cuda.empty_cache() + return val_metrics, timing_metrics + + +def async_grpo_train( + policy: ColocatablePolicyInterface, + policy_generation: Optional[GenerationInterface], + dataloader: StatefulDataLoader, + val_dataloader: Optional[StatefulDataLoader], + tokenizer: TokenizerType, + loss_fn: LossFunction, + task_to_env: dict[str, EnvironmentInterface], + val_task_to_env: Optional[dict[str, EnvironmentInterface]], + logger: Logger, + checkpointer: CheckpointManager, + grpo_save_state: GRPOSaveState, + master_config: MasterConfig, + max_trajectory_age_steps: int = 1, +) -> None: + """Run asynchronous GRPO training with replay buffer. + + Args: + policy: Training policy + policy_generation: Generation interface + dataloader: Training data loader + val_dataloader: Validation data loader + tokenizer: Tokenizer + loss_fn: Loss function + task_to_env: Training environments + val_task_to_env: Validation environments + logger: Logger + checkpointer: Checkpoint manager + grpo_save_state: Training state + master_config: Master configuration + max_trajectory_age_steps: Maximum age (in training steps) for trajectories to be used in training + """ + # Ensure we are running with a compatible async generation backend + assert _should_use_async_rollouts(master_config), ( + "Async GRPO requires vLLM backend with vllm_cfg.async_engine=True. " + "Set policy.generation.vllm_cfg.async_engine to true in your config." 
+ ) + assert master_config["loss_fn"]["use_importance_sampling_correction"] is True, ( + "Importance sampling correction must be enabled for async GRPO for good convergence due to off-policy samples!" + ) + # Import async utilities only when needed + from nemo_rl.algorithms.async_utils import AsyncTrajectoryCollector, ReplayBuffer + + timer = Timer() + NEED_REFIT = True + + # Setup generation interface + if policy_generation is None: + policy_generation = policy + NEED_REFIT = False + POLICY_GENERATION_STALE = True + assert policy_generation is not None + + # Training state + step = grpo_save_state["current_step"] + weight_version = step # Tracks refitted weight versions + consumed_samples = grpo_save_state["consumed_samples"] + val_period = master_config["grpo"]["val_period"] + val_at_start = master_config["grpo"]["val_at_start"] + colocated_inference = master_config["policy"]["generation"]["colocated"]["enabled"] + + assert not colocated_inference, ( + "Colocated inference is not supported for async GRPO. Please use non-colocated inference." 
+ ) + + # Calculate minimum buffer size from training requirements + # In per-prompt buffer mode, one buffer entry is 1 prompt * num_generations_per_prompt + num_prompts_per_step = master_config["grpo"]["num_prompts_per_step"] + samples_per_prompt_group = master_config["grpo"]["num_generations_per_prompt"] + train_gbs = master_config["policy"]["train_global_batch_size"] + + # Ensure the buffer has at least one step worth of prompt-groups before training + min_trajectories_needed = num_prompts_per_step + + print("πŸ“Š Buffer requirements calculation:") + print(f" - num_prompts_per_step: {num_prompts_per_step}") + print(f" - num_generations_per_prompt: {samples_per_prompt_group}") + print(f" - samples_per_prompt_group: {samples_per_prompt_group}") + print(f" - train_global_batch_size: {train_gbs}") + print(f" - min_trajectories_needed: {min_trajectories_needed} (async mode)") + + _replay_py_exec = get_actor_python_env( + "nemo_rl.algorithms.async_utils.ReplayBuffer" + ) + if _replay_py_exec.startswith("uv"): + # Lazily build a dedicated venv across all Ray nodes on-demand. 
+ _replay_py_exec = create_local_venv_on_each_node( + _replay_py_exec, + "nemo_rl.algorithms.async_utils.ReplayBuffer", + ) + + _replay_runtime_env = { + "py_executable": _replay_py_exec, + "env_vars": { + **os.environ, + "VIRTUAL_ENV": _replay_py_exec, + "UV_PROJECT_ENVIRONMENT": _replay_py_exec, + }, + } + + # Calculate optimal buffer size based on generation limits to prevent length bias + # Each weight version generates exactly num_prompts_per_step trajectories + # With max_age_steps, we keep trajectories from multiple weight versions + num_prompts_per_step = master_config["grpo"]["num_prompts_per_step"] + late_arrival_slack = 2 + optimal_buffer_size = ( + num_prompts_per_step * max_trajectory_age_steps * late_arrival_slack + ) + + replay_buffer = ReplayBuffer.options(runtime_env=_replay_runtime_env).remote( + max_size=optimal_buffer_size + ) + + _tc_py_exec = get_actor_python_env( + "nemo_rl.algorithms.async_utils.AsyncTrajectoryCollector" + ) + if _tc_py_exec.startswith("uv"): + _tc_py_exec = create_local_venv_on_each_node( + _tc_py_exec, + "nemo_rl.algorithms.async_utils.AsyncTrajectoryCollector", + ) + + _tc_runtime_env = { + "py_executable": _tc_py_exec, + "env_vars": { + **os.environ, + "VIRTUAL_ENV": _tc_py_exec, + "UV_PROJECT_ENVIRONMENT": _tc_py_exec, + }, + } + + # Initialize trajectory collector with synchronized collection + trajectory_collector = AsyncTrajectoryCollector.options( + runtime_env=_tc_runtime_env + ).remote( + policy_generation=policy_generation, + tokenizer=tokenizer, + task_to_env=task_to_env, + master_config=master_config, + replay_buffer=replay_buffer, + start_step=step, + ) + + # Start trajectory collection in background + collection_task = trajectory_collector.start_collection.remote(dataloader) + + # Ensure collector knows initial weight version + trajectory_collector.set_weight_version.remote(weight_version) + + print("πŸ“¦ Started continuous background trajectory collection") + + print( + f"πŸš€ Starting async GRPO training 
with buffer_size={optimal_buffer_size}, max_age={max_trajectory_age_steps} steps" + ) + + print("⏳ Preparing policy generation for training...") + if NEED_REFIT and POLICY_GENERATION_STALE: + print("πŸ”„ Refitting policy generation with actual model weights...") + try: + refit_policy_generation(policy, policy_generation, colocated_inference) + print("βœ… Policy generation refit completed successfully") + POLICY_GENERATION_STALE = False + except Exception as e: + print(f"❌ Policy generation refit failed: {e}") + import traceback + + traceback.print_exc() + return + else: + print("πŸ”„ Preparing policy generation for inference...") + try: + policy_generation.prepare_for_generation() + print("βœ… Policy generation preparation completed successfully") + except Exception as e: + print(f"❌ Policy generation preparation failed: {e}") + import traceback + + traceback.print_exc() + return + + print("βœ… Policy generation setup complete, proceeding to validation...") + + # Run validation at start if configured + if val_at_start and step == 0: + print("\nπŸ” Running initial validation...") + # Pause trajectory collection during initial validation + trajectory_collector.pause.remote() + + try: + val_metrics, validation_timings = validate( + policy_generation, + val_dataloader, + tokenizer, + val_task_to_env, + step=0, + master_config=master_config, + ) + policy_generation.finish_generation() + logger.log_metrics(val_metrics, step, prefix="validation") + logger.log_metrics(validation_timings, step, prefix="timing/validation") + print("βœ… Initial validation completed successfully") + except Exception as e: + print(f"❌ Initial validation failed: {e}") + import traceback + + traceback.print_exc() + # Continue anyway since validation is optional + finally: + # Resume trajectory collection after initial validation + trajectory_collector.resume.remote() + + print("βœ… All setup complete, starting buffer wait...") + + # Wait for initial buffer fill + print( + f"⏳ Waiting for replay 
buffer to have sufficient trajectories ({min_trajectories_needed} trajectories)..." + ) + wait_iterations = 0 + while True: + buffer_size_current = ray.get(replay_buffer.size.remote()) + + print( + f" Wait iteration {wait_iterations}: buffer_filled_ratio={buffer_size_current}/{min_trajectories_needed}" + ) + + if buffer_size_current >= min_trajectories_needed: + break + + time.sleep(1.0) + + print("βœ… Buffer ready! Starting training loop...") + + # Main training loop + try: + while step < master_config["grpo"]["max_num_steps"]: + print( + f"\n{'=' * 25} Step {step + 1}/{master_config['grpo']['max_num_steps']} {'=' * 25}" + ) + maybe_gpu_profile_step(policy, step + 1) + if policy != policy_generation: + maybe_gpu_profile_step(policy_generation, step + 1) + + with timer.time("total_step_time"): + # Sample trajectories from replay buffer + print("πŸ“¦ Sampling from replay buffer...") + with timer.time("buffer_sampling"): + buffer_size_current = ray.get(replay_buffer.size.remote()) + print( + f"πŸ“Š Step coordination: training_step={step}, max_age={max_trajectory_age_steps}, buffer_size={buffer_size_current}" + ) + + # Sample the required number of per-prompt groups. + num_prompt_groups_needed = master_config["grpo"][ + "num_prompts_per_step" + ] + sample_result = ray.get( + replay_buffer.sample.remote( + num_prompt_groups=num_prompt_groups_needed, + current_weight_version=weight_version, + max_age_steps=max_trajectory_age_steps, + ) + ) + + if ( + sample_result is None + or len(sample_result["trajectories"]) + != num_prompt_groups_needed + ): + print( + "⏳ Buffer empty or not enough groups to form a full step, waiting..." + ) + + # Get buffer debug info to help diagnose the issue + buffer_debug = ray.get(replay_buffer.get_debug_info.remote()) + buffer_size = buffer_debug["total_trajectories"] + + if buffer_size > 0: + print( + f"πŸ” Debug: Buffer has {buffer_size} trajectories but sampling requires exactly {num_prompt_groups_needed}." 
+ ) + print(f" Current weight version: {weight_version}") + print(f" Max trajectory age: {max_trajectory_age_steps}") + print( + f" Trajectory versions in buffer: {buffer_debug['trajectory_versions']}" + ) + + time.sleep(0.5) + continue + + # Extract trajectories and metadata from sample result + trajectories = sample_result["trajectories"] + avg_trajectory_age = sample_result["avg_trajectory_age"] + + print( + f"βœ… Sampled {len(trajectories)} trajectory groups from buffer (avg age: {avg_trajectory_age:.2f} steps)" + ) + + # Concatenate per-prompt groups into a single training batch + per_prompt_batches = [t["batch"] for t in trajectories] + repeated_batch = BatchedDataDict.from_batches(per_prompt_batches) + # Aggregate rollout metrics across groups (simple mean where applicable) + rollout_metrics = {} + for t in trajectories: + for k, v in t["rollout_metrics"].items(): + rollout_metrics.setdefault(k, []).append(v) + rollout_metrics = { + k: (sum(v) / len(v) if isinstance(v[0], (int, float)) else v) + for k, v in rollout_metrics.items() + } + + # Enforce fixed training batch: num_prompts_per_step * num_generations_per_prompt + expected_batch_size = ( + master_config["grpo"]["num_prompts_per_step"] + * master_config["grpo"]["num_generations_per_prompt"] + ) + if repeated_batch.size != expected_batch_size: + print( + f"❌ Unexpected training batch size: got {repeated_batch.size}, expected {expected_batch_size}. Skipping step and waiting for correct buffer content." + ) + time.sleep(0.5) + continue + + # Optional sanity: ensure DP divisibility to avoid sharding issues + dp_size = policy.sharding_annotations.get_axis_size("data_parallel") + if expected_batch_size % dp_size != 0: + raise AssertionError( + f"Configuration error: (num_prompts_per_step * num_generations_per_prompt) = {expected_batch_size} must be divisible by data_parallel size {dp_size}." 
+ ) + + print(f"Got trajectory batch (size: {repeated_batch.size})") + + print("β–Ά Processing rewards...") + with timer.time("reward_calculation"): + prompt_only_message_logs = [] + for message_log in repeated_batch["message_log"]: + prompt_only_log = [] + for message in message_log: + if message["role"] == "user" or message["role"] == "system": + prompt_only_log.append(message) + prompt_only_message_logs.append(prompt_only_log) + + prompt_batched_flat, prompt_input_lengths = ( + batched_message_log_to_flat_message( + prompt_only_message_logs, + pad_value_dict={"token_ids": tokenizer.pad_token_id}, + ) + ) + prompt_only_ids = prompt_batched_flat["token_ids"] + + rewards = repeated_batch["total_reward"] + + print("β–Ά Computing advantages...") + + baseline, std = calculate_baseline_and_std_per_prompt( + prompt_only_ids, + rewards, + torch.ones_like(rewards), + leave_one_out_baseline=master_config["grpo"][ + "use_leave_one_out_baseline" + ], + ) + advantages = (rewards - baseline).unsqueeze(-1) + + print( + f" πŸ“Š Rewards stats: min={rewards.min():.4f}, max={rewards.max():.4f}, mean={rewards.mean():.4f}, std={rewards.std():.4f}" + ) + print( + f" πŸ“Š Baseline stats: min={baseline.min():.4f}, max={baseline.max():.4f}, mean={baseline.mean():.4f}" + ) + print( + f" πŸ“Š Advantages stats: min={advantages.min():.4f}, max={advantages.max():.4f}, mean={advantages.mean():.4f}, std={advantages.std():.4f}" + ) + + if master_config["grpo"]["normalize_rewards"]: + zero_std_mask = std > 0 + advantages[zero_std_mask] = ( + advantages[zero_std_mask] / std.unsqueeze(-1)[zero_std_mask] + ) + print( + f" πŸ“Š Normalized advantages stats: min={advantages.min():.4f}, max={advantages.max():.4f}, mean={advantages.mean():.4f}, std={advantages.std():.4f}" + ) + + # Prepare training data (same as sync version) + with timer.time("data_processing"): + # Add loss mask and advantages to each message + for i, message_log in enumerate(repeated_batch["message_log"]): + for j, message in 
enumerate(message_log): + if message["role"] == "assistant": + message["token_loss_mask"] = torch.ones_like( + message["token_ids"] + ) + else: + message["token_loss_mask"] = torch.zeros_like( + message["token_ids"] + ) + if "generation_logprobs" not in message: + message["generation_logprobs"] = torch.zeros_like( + message["token_ids"], dtype=torch.float32 + ) + message["advantages"] = advantages[i].expand( + message["token_ids"].shape + ) + + # Convert to flat format for training + flat_messages, input_lengths = batched_message_log_to_flat_message( + repeated_batch["message_log"], + pad_value_dict={"token_ids": tokenizer.pad_token_id}, + make_sequence_length_divisible_by=master_config["policy"][ + "make_sequence_length_divisible_by" + ], + ) + + # Create training data + train_data = BatchedDataDict[ClippedPGLossDataDict]( + { + "input_ids": flat_messages["token_ids"], + "input_lengths": input_lengths, + "advantages": flat_messages["advantages"], + "generation_logprobs": flat_messages["generation_logprobs"], + "token_mask": flat_messages["token_loss_mask"], + "sample_mask": repeated_batch["loss_multiplier"], + } + ) + train_data.to("cpu") + + # Training phase (same as sync version) + print("β–Ά Preparing for logprob inference...") + with timer.time("logprob_inference_prep"): + policy.prepare_for_lp_inference() + + print("β–Ά Computing logprobs...") + with timer.time("policy_and_reference_logprobs"): + fprop_logprobs = policy.get_logprobs(train_data)["logprobs"] + reference_logprobs = policy.get_reference_policy_logprobs( + train_data + )["reference_logprobs"] + train_data["prev_logprobs"] = fprop_logprobs + train_data["reference_policy_logprobs"] = reference_logprobs + + print("β–Ά Preparing for training...") + with timer.time("training_prep"): + policy.prepare_for_training() + POLICY_GENERATION_STALE = True + + print("β–Ά Training policy...") + with timer.time("policy_training"): + train_results = policy.train(train_data, loss_fn) + + print("πŸ”„ Synchronizing 
policy weights to trajectory collector…") + if NEED_REFIT: + # Measure pending-generation wait as exposed_generation time + print("πŸ”„ Coordinating with trajectory collector before refit...") + with timer.time("exposed_generation"): + ray.get(trajectory_collector.prepare_for_refit.remote()) + + # Only the actual refit/weight transfer should be counted as weight_sync + print("πŸ”„ Performing policy generation refit...") + with timer.time("weight_sync"): + refit_policy_generation( + policy, policy_generation, colocated_inference + ) + POLICY_GENERATION_STALE = False + + # Update weight version before resuming trajectory collection so that all trajectories are updated with the new correct weight version + weight_version += 1 + trajectory_collector.set_weight_version.remote(weight_version) + trajectory_collector.resume_after_refit.remote() + + # Validation + val_metrics, validation_timings = None, None + is_last_step = step + 1 == master_config["grpo"]["max_num_steps"] + + if val_period > 0 and (step + 1) % val_period == 0: + # Pause trajectory collection during validation to reduce memory pressure + trajectory_collector.pause.remote() + + if NEED_REFIT and POLICY_GENERATION_STALE: + refit_policy_generation( + policy, policy_generation, colocated_inference + ) + POLICY_GENERATION_STALE = False + else: + policy_generation.prepare_for_generation() + val_metrics, validation_timings = validate( + policy_generation, + val_dataloader, + tokenizer, + val_task_to_env, + step=step + 1, + master_config=master_config, + ) + policy_generation.finish_generation() + logger.log_metrics( + validation_timings, step + 1, prefix="timing/validation" + ) + logger.log_metrics(val_metrics, step + 1, prefix="validation") + + # Explicit GPU memory cleanup after validation in async mode + import gc + + gc.collect() + torch.cuda.empty_cache() + + # Resume trajectory collection after validation + trajectory_collector.resume.remote() + + # Checkpointing (same as sync version) + consumed_samples 
+= master_config["grpo"]["num_prompts_per_step"] + if master_config["checkpointing"]["enabled"] and ( + is_last_step + or (step + 1) % master_config["checkpointing"]["save_period"] == 0 + ): + policy.prepare_for_training() + + grpo_save_state["current_step"] = step + 1 + if val_metrics is not None: + grpo_save_state["val_reward"] = val_metrics["accuracy"] + elif "val_reward" in grpo_save_state: + del grpo_save_state["val_reward"] + grpo_save_state["consumed_samples"] = consumed_samples + + if master_config["checkpointing"]["metric_name"] is not None: + if ( + master_config["checkpointing"]["metric_name"] + not in grpo_save_state + ): + warnings.warn( + f"You asked to save checkpoints based on {master_config['checkpointing']['metric_name']} but the metric is not found in the save state. " + "Saving most recent k checkpoints instead." + ) + master_config["checkpointing"]["metric_name"] = None + + with timer.time("checkpointing"): + print(f"Saving checkpoint for step {step + 1}...") + checkpoint_path = checkpointer.init_tmp_checkpoint( + step + 1, grpo_save_state, master_config + ) + policy.save_checkpoint( + weights_path=os.path.join( + checkpoint_path, "policy", "weights" + ), + optimizer_path=os.path.join( + checkpoint_path, "policy", "optimizer" + ), + tokenizer_path=os.path.join( + checkpoint_path, "policy", "tokenizer" + ), + ) + # Get dataloader state from trajectory collector + actual_dataloader_state = ray.get( + trajectory_collector.get_dataloader_state.remote() + ) + torch.save( + actual_dataloader_state, + os.path.join(checkpoint_path, "train_dataloader.pt"), + ) + checkpointer.finalize_checkpoint(checkpoint_path) + policy.offload_after_refit() + + log_data = {"content": flat_messages["content"]} + log_data["rewards"] = rewards.tolist() + log_data["generation_logprobs"] = train_data["generation_logprobs"].tolist() + log_data["prev_logprobs"] = train_data["prev_logprobs"].tolist() + log_data["input_lengths"] = input_lengths.tolist() + 
logger.log_batched_dict_as_jsonl(log_data, f"train_data_step{step}.jsonl") + + metrics = { + "loss": train_results["loss"].numpy(), + "reward": rewards.numpy(), + "grad_norm": train_results["grad_norm"].numpy(), + } + metrics.update(train_results["all_mb_metrics"]) + for k, v in metrics.items(): + if k in { + "lr", + "wd", + "reward", + "global_valid_seqs", + "global_valid_toks", + }: + metrics[k] = np.mean(v).item() + else: + metrics[k] = np.sum(v).item() + metrics.update(rollout_metrics) + + timing_metrics: dict[str, float] = timer.get_timing_metrics( + reduction_op="sum" + ) + + # Add buffer stats + buffer_size_current = ray.get(replay_buffer.size.remote()) + metrics["buffer_size"] = buffer_size_current + metrics["avg_trajectory_age"] = avg_trajectory_age + + print("\nπŸ“Š Training Results:") + print(f" β€’ Loss: {metrics['loss']:.4f}") + print(f" β€’ Avg Reward: {np.mean(rewards.numpy()):.4f}") + print(f" β€’ Buffer Size: {buffer_size_current}") + print(f" β€’ Avg Trajectory Age: {avg_trajectory_age:.2f} steps") + + print("\n⏱️ Timing:") + total_time = timing_metrics.get("total_step_time", 0) + print(f" β€’ Total step time: {total_time:.2f}s") + for k, v in sorted( + timing_metrics.items(), key=lambda item: item[1], reverse=True + ): + if k != "total_step_time": + percent = (v / total_time * 100) if total_time > 0 else 0 + print(f" β€’ {k}: {v:.2f}s ({percent:.1f}%)") + + logger.log_metrics(metrics, step + 1, prefix="train") + logger.log_metrics(timing_metrics, step + 1, prefix="timing/train") + + timer.reset() + step += 1 + + finally: + # Clean up + print("πŸ›‘ Stopping trajectory collection...") + try: + ray.kill(trajectory_collector) + except Exception as e: + print(f"Error stopping trajectory collector: {e}") + + try: + ray.kill(replay_buffer) + except Exception as e: + print(f"Error stopping replay buffer: {e}") + + print("Async GRPO training complete!") diff --git a/nemo_rl/distributed/ray_actor_environment_registry.py 
b/nemo_rl/distributed/ray_actor_environment_registry.py index ca506abe0d..7f912648c6 100644 --- a/nemo_rl/distributed/ray_actor_environment_registry.py +++ b/nemo_rl/distributed/ray_actor_environment_registry.py @@ -37,6 +37,10 @@ "nemo_rl.environments.code_environment.CodeEnvironment": PY_EXECUTABLES.SYSTEM, "nemo_rl.environments.reward_model_environment.RewardModelEnvironment": PY_EXECUTABLES.SYSTEM, "nemo_rl.environments.games.sliding_puzzle.SlidingPuzzleEnv": PY_EXECUTABLES.SYSTEM, + # AsyncTrajectoryCollector needs vLLM environment to handle exceptions from VllmGenerationWorker + "nemo_rl.algorithms.async_utils.AsyncTrajectoryCollector": PY_EXECUTABLES.VLLM, + # ReplayBuffer needs vLLM environment to handle trajectory data from VllmGenerationWorker + "nemo_rl.algorithms.async_utils.ReplayBuffer": PY_EXECUTABLES.VLLM, "nemo_rl.environments.tools.retriever.RAGEnvironment": PY_EXECUTABLES.SYSTEM, } diff --git a/tests/functional/L1_Functional_Tests_GPU.sh b/tests/functional/L1_Functional_Tests_GPU.sh index 8c23fc8c7e..ae87fb6a44 100644 --- a/tests/functional/L1_Functional_Tests_GPU.sh +++ b/tests/functional/L1_Functional_Tests_GPU.sh @@ -18,6 +18,7 @@ set -xeuo pipefail # Exit immediately if a command exits with a non-zero status cd /opt/nemo-rl time uv run --no-sync bash ./tests/functional/sft.sh time uv run --no-sync bash ./tests/functional/grpo.sh +time uv run --no-sync bash ./tests/functional/grpo_async.sh time uv run --no-sync bash ./tests/functional/grpo_megatron.sh time uv run --no-sync bash ./tests/functional/grpo_multiturn.sh time uv run --no-sync bash ./tests/functional/grpo_non_colocated.sh diff --git a/tests/functional/grpo_async.sh b/tests/functional/grpo_async.sh new file mode 100644 index 0000000000..1e14266e97 --- /dev/null +++ b/tests/functional/grpo_async.sh @@ -0,0 +1,49 @@ +#!/bin/bash + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) +PROJECT_ROOT=$(realpath $SCRIPT_DIR/../..) 
+# Mark the current repo as safe, since wandb fetches metadata about the repo +git config --global --add safe.directory $PROJECT_ROOT + +set -eou pipefail + +EXP_NAME=$(basename $0 .sh) +EXP_DIR=$SCRIPT_DIR/$EXP_NAME +LOG_DIR=$EXP_DIR/logs +JSON_METRICS=$EXP_DIR/metrics.json +RUN_LOG=$EXP_DIR/run.log +export PYTHONPATH=${PROJECT_ROOT}:${PYTHONPATH:-} + +rm -rf $EXP_DIR $LOG_DIR +mkdir -p $EXP_DIR $LOG_DIR + +cd $PROJECT_ROOT +uv run coverage run -a --data-file=$PROJECT_ROOT/tests/.coverage --source=$PROJECT_ROOT/nemo_rl \ + $PROJECT_ROOT/examples/run_grpo_math.py \ + policy.model_name=Qwen/Qwen3-0.6B \ + grpo.num_prompts_per_step=2 \ + grpo.num_generations_per_prompt=4 \ + policy.train_global_batch_size=4 \ + policy.train_micro_batch_size=1 \ + cluster.gpus_per_node=2 \ + grpo.max_num_steps=20 \ + grpo.async_grpo.enabled=true \ + grpo.async_grpo.max_trajectory_age_steps=1 \ + policy.generation.vllm_cfg.async_engine=true \ + loss_fn.use_importance_sampling_correction=true \ + policy.generation.colocated.enabled=false \ + policy.generation.colocated.resources.num_nodes=1 \ + policy.generation.colocated.resources.gpus_per_node=1 \ + logger.tensorboard_enabled=true \ + logger.log_dir=$LOG_DIR \ + logger.wandb_enabled=false \ + logger.monitor_gpus=true \ + checkpointing.enabled=false \ + $@ \ + 2>&1 | tee $RUN_LOG + +uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS + +uv run tests/check_metrics.py $JSON_METRICS \ + 'max(data["train/token_mult_prob_error"]) < 1.05' + diff --git a/tests/unit/algorithms/test_async_utils.py b/tests/unit/algorithms/test_async_utils.py new file mode 100644 index 0000000000..c8eb1639bd --- /dev/null +++ b/tests/unit/algorithms/test_async_utils.py @@ -0,0 +1,700 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import tempfile +import threading +import unittest.mock as mock + +import pytest +import ray +import torch + +# Set up Ray temp directory before any Ray operations +# Try multiple approaches to ensure Ray uses a writable directory +_temp_dir = tempfile.mkdtemp(prefix="ray_async_test_") +os.environ["RAY_TEMP_DIR"] = _temp_dir +os.environ["RAY_TMPDIR"] = _temp_dir # Alternative env var +os.environ["TMPDIR"] = _temp_dir # System temp dir + +from nemo_rl.algorithms.async_utils import AsyncTrajectoryCollector, ReplayBuffer +from nemo_rl.algorithms.grpo import MasterConfig +from nemo_rl.data.interfaces import DatumSpec, LLMMessageLogType +from nemo_rl.distributed.batched_data_dict import BatchedDataDict +from nemo_rl.environments.interfaces import ( + EnvironmentInterface, + EnvironmentReturn, +) + + +@ray.remote(num_cpus=0) +class MockEnvironment(EnvironmentInterface): + """Mock environment for testing async utilities.""" + + def __init__(self, rewards: list[float]): + self.rewards = rewards + self._calls = 0 + + def step( + self, messages: list[LLMMessageLogType], env_info: list[dict] + ) -> EnvironmentReturn: + self._calls += 1 + return ( + [{"role": "environment", "content": "observation"}] * len(messages), + [{}] * len(messages), + [[]] * len(messages), + self.rewards, + [True] * len(messages), + [None] * len(messages), + ) + + def get_calls(self): + return self._calls + + def reset_calls(self): + self._calls = 0 + return True + + def global_post_process_and_metrics( + self, batch: BatchedDataDict + ) -> tuple[BatchedDataDict, dict]: + return 
batch, {} + + +class MockGenerationInterface: + """Mock generation interface for testing.""" + + def __init__(self): + self.prepare_calls = 0 + self.finish_calls = 0 + + def prepare_for_generation(self, **kwargs): + self.prepare_calls += 1 + + def finish_generation(self): + self.finish_calls += 1 + + +class TestReplayBuffer: + """Test cases for ReplayBuffer.""" + + def test_replay_buffer_initialization(self): + """Test ReplayBuffer initialization.""" + buffer = ReplayBuffer.remote(max_size=10) + size = ray.get(buffer.size.remote()) + assert size == 0 + + debug_info = ray.get(buffer.get_debug_info.remote()) + assert debug_info["total_trajectories"] == 0 + assert debug_info["max_size"] == 10 + assert debug_info["trajectory_versions"] == [] + assert debug_info["target_weight_versions"] == [] + + ray.kill(buffer) + + def test_replay_buffer_push_and_size(self): + """Test pushing trajectories to buffer.""" + buffer = ReplayBuffer.remote(max_size=3) + + # Create mock trajectories + trajectory1 = {"batch": {"data": "test1"}, "rollout_metrics": {"reward": 1.0}} + trajectory2 = {"batch": {"data": "test2"}, "rollout_metrics": {"reward": 2.0}} + + # Push trajectories + status1 = ray.get( + buffer.push_with_wait_signal.remote( + trajectory1, weight_version=0, target_weight_version=1 + ) + ) + assert status1 == "success" + + status2 = ray.get( + buffer.push_with_wait_signal.remote( + trajectory2, weight_version=1, target_weight_version=2 + ) + ) + assert status2 == "success" + + # Check size + size = ray.get(buffer.size.remote()) + assert size == 2 + + # Check debug info + debug_info = ray.get(buffer.get_debug_info.remote()) + assert debug_info["total_trajectories"] == 2 + assert debug_info["trajectory_versions"] == [0, 1] + assert debug_info["target_weight_versions"] == [1, 2] + + ray.kill(buffer) + + def test_replay_buffer_max_size_limit(self): + """Test that buffer respects max size limit.""" + buffer = ReplayBuffer.remote(max_size=2) + + # Fill buffer to capacity + 
trajectory1 = {"batch": {"data": "test1"}, "rollout_metrics": {"reward": 1.0}} + trajectory2 = {"batch": {"data": "test2"}, "rollout_metrics": {"reward": 2.0}} + trajectory3 = {"batch": {"data": "test3"}, "rollout_metrics": {"reward": 3.0}} + + # Push first two trajectories + status1 = ray.get( + buffer.push_with_wait_signal.remote( + trajectory1, weight_version=0, target_weight_version=1 + ) + ) + status2 = ray.get( + buffer.push_with_wait_signal.remote( + trajectory2, weight_version=1, target_weight_version=2 + ) + ) + assert status1 == "success" + assert status2 == "success" + + # Try to push third trajectory (should return "full") + status3 = ray.get( + buffer.push_with_wait_signal.remote( + trajectory3, weight_version=2, target_weight_version=3 + ) + ) + assert status3 == "full" + + # Size should still be 2 + size = ray.get(buffer.size.remote()) + assert size == 2 + + ray.kill(buffer) + + def test_replay_buffer_sampling_basic(self): + """Test basic trajectory sampling.""" + buffer = ReplayBuffer.remote(max_size=10) + + # Push trajectories with different weight versions + trajectories = [] + for i in range(3): + trajectory = { + "batch": {"data": f"test{i}"}, + "rollout_metrics": {"reward": float(i)}, + } + trajectories.append(trajectory) + ray.get( + buffer.push_with_wait_signal.remote( + trajectory, weight_version=i, target_weight_version=i + 1 + ) + ) + + # Sample trajectories intended for current step 2 + sample_result = ray.get( + buffer.sample.remote( + num_prompt_groups=1, + current_weight_version=2, + max_age_steps=2, + ) + ) + + assert sample_result is not None + assert len(sample_result["trajectories"]) == 1 + assert "avg_trajectory_age" in sample_result + + # The trajectory should be intended for step 2 (target_weight_version=2) + # But we pushed with target_weight_version=i+1, so trajectory at i=1 has target=2 + sampled_trajectory = sample_result["trajectories"][0] + assert sampled_trajectory["batch"]["data"] == "test1" + + ray.kill(buffer) + + def 
test_replay_buffer_sampling_insufficient_trajectories(self): + """Test sampling when insufficient trajectories are available.""" + buffer = ReplayBuffer.remote(max_size=10) + + # Push only one trajectory + trajectory = {"batch": {"data": "test"}, "rollout_metrics": {"reward": 1.0}} + ray.get( + buffer.push_with_wait_signal.remote( + trajectory, weight_version=0, target_weight_version=1 + ) + ) + + # Try to sample more trajectories than available for current step + sample_result = ray.get( + buffer.sample.remote( + num_prompt_groups=2, # Request 2 but only 1 available + current_weight_version=1, + max_age_steps=1, + ) + ) + + assert sample_result is None # Should return None when insufficient + + ray.kill(buffer) + + def test_replay_buffer_age_filtering(self): + """Test that old trajectories are filtered out.""" + buffer = ReplayBuffer.remote(max_size=10) + + # Push trajectories with different ages + old_trajectory = {"batch": {"data": "old"}, "rollout_metrics": {"reward": 1.0}} + recent_trajectory = { + "batch": {"data": "recent"}, + "rollout_metrics": {"reward": 2.0}, + } + + ray.get( + buffer.push_with_wait_signal.remote( + old_trajectory, weight_version=0, target_weight_version=1 + ) + ) + ray.get( + buffer.push_with_wait_signal.remote( + recent_trajectory, weight_version=2, target_weight_version=3 + ) + ) + + # Sample with current_weight_version=3 and max_age_steps=1 + # This should filter out the trajectory with weight_version=0 (too old) + with pytest.raises( + ValueError, match="Found .* trajectories older than min_valid_version" + ): + ray.get( + buffer.sample.remote( + num_prompt_groups=1, + current_weight_version=3, + max_age_steps=1, + ) + ) + + ray.kill(buffer) + + def test_replay_buffer_target_weight_matching(self): + """Test that sampling only returns trajectories intended for current step.""" + buffer = ReplayBuffer.remote(max_size=10) + + # Push trajectories intended for different target steps + trajectory1 = { + "batch": {"data": "for_step_1"}, + 
"rollout_metrics": {"reward": 1.0}, + } + trajectory2 = { + "batch": {"data": "for_step_2"}, + "rollout_metrics": {"reward": 2.0}, + } + + ray.get( + buffer.push_with_wait_signal.remote( + trajectory1, weight_version=0, target_weight_version=1 + ) + ) + ray.get( + buffer.push_with_wait_signal.remote( + trajectory2, weight_version=1, target_weight_version=2 + ) + ) + + # Sample for current step 1 - should only get trajectory intended for step 1 + sample_result = ray.get( + buffer.sample.remote( + num_prompt_groups=1, + current_weight_version=1, + max_age_steps=2, + ) + ) + + assert sample_result is not None + assert len(sample_result["trajectories"]) == 1 + assert sample_result["trajectories"][0]["batch"]["data"] == "for_step_1" + + ray.kill(buffer) + + def test_replay_buffer_get_existing_target_weights(self): + """Test getting existing target weight versions.""" + buffer = ReplayBuffer.remote(max_size=10) + + # Initially empty + existing_weights = ray.get(buffer.get_existing_target_weights.remote()) + assert existing_weights == set() + + # Push trajectories with different target weights + trajectory1 = {"batch": {"data": "test1"}, "rollout_metrics": {"reward": 1.0}} + trajectory2 = {"batch": {"data": "test2"}, "rollout_metrics": {"reward": 2.0}} + + ray.get( + buffer.push_with_wait_signal.remote( + trajectory1, weight_version=0, target_weight_version=1 + ) + ) + ray.get( + buffer.push_with_wait_signal.remote( + trajectory2, weight_version=1, target_weight_version=3 + ) + ) + + existing_weights = ray.get(buffer.get_existing_target_weights.remote()) + assert existing_weights == {1, 3} + + ray.kill(buffer) + + def test_replay_buffer_clear(self): + """Test clearing the buffer.""" + buffer = ReplayBuffer.remote(max_size=10) + + # Push some trajectories + trajectory = {"batch": {"data": "test"}, "rollout_metrics": {"reward": 1.0}} + ray.get( + buffer.push_with_wait_signal.remote( + trajectory, weight_version=0, target_weight_version=1 + ) + ) + + # Verify buffer has 
content + size = ray.get(buffer.size.remote()) + assert size == 1 + + # Clear buffer + ray.get(buffer.clear.remote()) + + # Verify buffer is empty + size = ray.get(buffer.size.remote()) + assert size == 0 + + debug_info = ray.get(buffer.get_debug_info.remote()) + assert debug_info["total_trajectories"] == 0 + assert debug_info["trajectory_versions"] == [] + assert debug_info["target_weight_versions"] == [] + + ray.kill(buffer) + + +class TestAsyncTrajectoryCollector: + """Test cases for AsyncTrajectoryCollector.""" + + def create_mock_config(self) -> MasterConfig: + """Create a mock master config for testing.""" + return { + "grpo": { + "num_prompts_per_step": 2, + "num_generations_per_prompt": 3, + "max_rollout_turns": 1, + "async_grpo": {"max_trajectory_age_steps": 2}, + }, + "policy": {"max_total_sequence_length": 512}, + } + + def create_mock_batch(self, size: int = 2) -> BatchedDataDict[DatumSpec]: + """Create a mock batch for testing.""" + message_logs = [] + for i in range(size): + message_logs.append( + [ + {"role": "user", "content": f"Test prompt {i}"}, + ] + ) + + return BatchedDataDict[DatumSpec]( + { + "task_name": ["test"] * size, + "message_log": message_logs, + "extra_env_info": [{}] * size, + "loss_multiplier": torch.ones(size), + } + ) + + def test_async_trajectory_collector_initialization(self): + """Test AsyncTrajectoryCollector initialization.""" + buffer = ReplayBuffer.remote(max_size=10) + mock_generation = MockGenerationInterface() + mock_tokenizer = mock.MagicMock() + mock_env = MockEnvironment.remote(rewards=[1.0, 2.0]) + task_to_env = {"test": mock_env} + master_config = self.create_mock_config() + + collector = AsyncTrajectoryCollector.remote( + policy_generation=mock_generation, + tokenizer=mock_tokenizer, + task_to_env=task_to_env, + master_config=master_config, + replay_buffer=buffer, + start_step=0, + ) + + # Test basic functionality + weight_version = ray.get(collector.get_weight_version.remote()) + assert weight_version == 0 + + 
ray.kill(collector) + ray.kill(buffer) + ray.kill(mock_env) + + def test_async_trajectory_collector_weight_version_updates(self): + """Test weight version updates in trajectory collector.""" + buffer = ReplayBuffer.remote(max_size=10) + mock_generation = MockGenerationInterface() + mock_tokenizer = mock.MagicMock() + mock_env = MockEnvironment.remote(rewards=[1.0, 2.0]) + task_to_env = {"test": mock_env} + master_config = self.create_mock_config() + + collector = AsyncTrajectoryCollector.remote( + policy_generation=mock_generation, + tokenizer=mock_tokenizer, + task_to_env=task_to_env, + master_config=master_config, + replay_buffer=buffer, + start_step=0, + ) + + # Update weight version + ray.get(collector.set_weight_version.remote(5)) + weight_version = ray.get(collector.get_weight_version.remote()) + assert weight_version == 5 + + ray.kill(collector) + ray.kill(buffer) + ray.kill(mock_env) + + def test_async_trajectory_collector_pause_resume(self): + """Test pause and resume functionality.""" + buffer = ReplayBuffer.remote(max_size=10) + mock_generation = MockGenerationInterface() + mock_tokenizer = mock.MagicMock() + mock_env = MockEnvironment.remote(rewards=[1.0, 2.0]) + task_to_env = {"test": mock_env} + master_config = self.create_mock_config() + + collector = AsyncTrajectoryCollector.remote( + policy_generation=mock_generation, + tokenizer=mock_tokenizer, + task_to_env=task_to_env, + master_config=master_config, + replay_buffer=buffer, + start_step=0, + ) + + # Test pause and resume (these should not raise errors) + ray.get(collector.pause.remote()) + ray.get(collector.resume.remote()) + + ray.kill(collector) + ray.kill(buffer) + ray.kill(mock_env) + + def test_async_trajectory_collector_prepare_for_refit(self): + """Test prepare for refit functionality.""" + buffer = ReplayBuffer.remote(max_size=10) + mock_generation = MockGenerationInterface() + mock_tokenizer = mock.MagicMock() + mock_env = MockEnvironment.remote(rewards=[1.0, 2.0]) + task_to_env = 
{"test": mock_env} + master_config = self.create_mock_config() + + collector = AsyncTrajectoryCollector.remote( + policy_generation=mock_generation, + tokenizer=mock_tokenizer, + task_to_env=task_to_env, + master_config=master_config, + replay_buffer=buffer, + start_step=0, + ) + + # Test prepare for refit (should complete without hanging) + ray.get(collector.prepare_for_refit.remote()) + ray.get(collector.resume_after_refit.remote()) + + ray.kill(collector) + ray.kill(buffer) + ray.kill(mock_env) + + def test_calculate_target_weights(self): + """Test target weight calculation logic.""" + buffer = ReplayBuffer.remote(max_size=10) + mock_generation = MockGenerationInterface() + mock_tokenizer = mock.MagicMock() + mock_env = MockEnvironment.remote(rewards=[1.0, 2.0]) + task_to_env = {"test": mock_env} + master_config = self.create_mock_config() + + collector = AsyncTrajectoryCollector.remote( + policy_generation=mock_generation, + tokenizer=mock_tokenizer, + task_to_env=task_to_env, + master_config=master_config, + replay_buffer=buffer, + start_step=0, + ) + + # Test target weight calculation with different scenarios + # Note: We can't directly test the private method, but we can test its effects + # through the public interface behavior + + ray.kill(collector) + ray.kill(buffer) + ray.kill(mock_env) + + def test_dataloader_state_retrieval(self): + """Test getting dataloader state for checkpointing.""" + buffer = ReplayBuffer.remote(max_size=10) + mock_generation = MockGenerationInterface() + mock_tokenizer = mock.MagicMock() + mock_env = MockEnvironment.remote(rewards=[1.0, 2.0]) + task_to_env = {"test": mock_env} + master_config = self.create_mock_config() + + collector = AsyncTrajectoryCollector.remote( + policy_generation=mock_generation, + tokenizer=mock_tokenizer, + task_to_env=task_to_env, + master_config=master_config, + replay_buffer=buffer, + start_step=0, + ) + + # Test getting dataloader state (should return empty dict when no dataloader) + state = 
ray.get(collector.get_dataloader_state.remote()) + assert isinstance(state, dict) + + ray.kill(collector) + ray.kill(buffer) + ray.kill(mock_env) + + +class TestAsyncUtilsIntegration: + """Integration tests for async utilities working together.""" + + def create_mock_config(self) -> MasterConfig: + """Create a mock master config for testing.""" + return { + "grpo": { + "num_prompts_per_step": 2, + "num_generations_per_prompt": 2, + "max_rollout_turns": 1, + "async_grpo": {"max_trajectory_age_steps": 1}, + }, + "policy": {"max_total_sequence_length": 512}, + } + + def create_mock_batch(self, size: int = 2) -> BatchedDataDict[DatumSpec]: + """Create a mock batch for testing.""" + message_logs = [] + for i in range(size): + message_logs.append( + [ + {"role": "user", "content": f"Test prompt {i}"}, + ] + ) + + return BatchedDataDict[DatumSpec]( + { + "task_name": ["test"] * size, + "message_log": message_logs, + "extra_env_info": [{}] * size, + "loss_multiplier": torch.ones(size), + } + ) + + def test_buffer_and_collector_integration(self): + """Test that buffer and collector work together correctly.""" + buffer = ReplayBuffer.remote(max_size=10) + mock_generation = MockGenerationInterface() + mock_tokenizer = mock.MagicMock() + mock_env = MockEnvironment.remote(rewards=[1.0, 2.0]) + task_to_env = {"test": mock_env} + master_config = self.create_mock_config() + + collector = AsyncTrajectoryCollector.remote( + policy_generation=mock_generation, + tokenizer=mock_tokenizer, + task_to_env=task_to_env, + master_config=master_config, + replay_buffer=buffer, + start_step=0, + ) + + # Verify initial state + buffer_size = ray.get(buffer.size.remote()) + assert buffer_size == 0 + + weight_version = ray.get(collector.get_weight_version.remote()) + assert weight_version == 0 + + # Test weight version synchronization + ray.get(collector.set_weight_version.remote(3)) + updated_version = ray.get(collector.get_weight_version.remote()) + assert updated_version == 3 + + 
ray.kill(collector) + ray.kill(buffer) + ray.kill(mock_env) + + def test_concurrent_operations(self): + """Test that concurrent operations don't cause race conditions.""" + buffer = ReplayBuffer.remote(max_size=5) + + # Push trajectories concurrently from multiple threads + def push_trajectory(buffer, trajectory_id): + trajectory = { + "batch": {"data": f"test{trajectory_id}"}, + "rollout_metrics": {"reward": float(trajectory_id)}, + } + return ray.get( + buffer.push_with_wait_signal.remote( + trajectory, + weight_version=trajectory_id, + target_weight_version=trajectory_id + 1, + ) + ) + + # Use threading to simulate concurrent pushes + threads = [] + results = [] + + def worker(traj_id): + result = push_trajectory(buffer, traj_id) + results.append(result) + + for i in range(3): + thread = threading.Thread(target=worker, args=(i,)) + threads.append(thread) + thread.start() + + for thread in threads: + thread.join() + + # All pushes should succeed + assert all(result == "success" for result in results) + + # Buffer should have correct size + final_size = ray.get(buffer.size.remote()) + assert final_size == 3 + + ray.kill(buffer) + + def test_error_handling(self): + """Test error handling in async utilities.""" + # Test with invalid buffer size + with pytest.raises(Exception): + buffer = ReplayBuffer.remote(max_size=-1) + ray.get(buffer.size.remote()) + + # Test buffer operations + buffer = ReplayBuffer.remote(max_size=1) + + # Test sampling from empty buffer + sample_result = ray.get( + buffer.sample.remote( + num_prompt_groups=1, + current_weight_version=0, + max_age_steps=1, + ) + ) + assert sample_result is None + + ray.kill(buffer)