9 changes: 9 additions & 0 deletions docs/guides/async-grpo.md
@@ -41,6 +41,8 @@ grpo:
async_grpo:
enabled: true
max_trajectory_age_steps: 1 # Maximum age, in training steps, for trajectories
in_flight_weight_updates: false # Enable for faster weight synchronization
recompute_kv_cache_after_weight_updates: false # Recompute the KV cache after in-flight weight updates
```

### Complete Example Config
@@ -65,6 +67,8 @@ grpo:
async_grpo:
enabled: true
max_trajectory_age_steps: 1
in_flight_weight_updates: false # Enable for faster weight synchronization
recompute_kv_cache_after_weight_updates: false # Recompute the KV cache after in-flight weight updates

cluster:
num_nodes: 2
@@ -158,6 +162,11 @@ sequenceDiagram

3. **Resource Allocation**: Ensure sufficient GPU memory for both the training and generation clusters

4. **In-Flight Weight Updates**: Enable `in_flight_weight_updates: true` when using `async_engine: true` to update the vLLM engine's weights during generation. This prevents stalling the training pipeline until the longest generation finishes and provides significant performance benefits.

5. **Recompute KV Cache After Weight Updates**: When using in-flight weight updates, you can choose whether to recompute
KV caches after each weight update by setting `recompute_kv_cache_after_weight_updates`.
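Taken together, items 4 and 5 suggest a config like the following (a sketch combining the keys shown above; the nesting assumes the same layout as the complete example config):

```yaml
grpo:
  async_grpo:
    enabled: true
    max_trajectory_age_steps: 2
    in_flight_weight_updates: true  # sync weights without waiting for pending generations
    recompute_kv_cache_after_weight_updates: true  # drop stale KV caches after each update

policy:
  generation:
    vllm_cfg:
      async_engine: true  # required for in-flight weight updates
```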


## Why Importance Sampling Correction Is Required for Async

### The GRPO Objective
2 changes: 2 additions & 0 deletions examples/configs/grpo_math_1B.yaml
@@ -32,6 +32,8 @@ grpo:
enabled: false # Set to true to enable async training mode
# Max age (in training steps) for trajectories used in training
max_trajectory_age_steps: 1
in_flight_weight_updates: false # Set to true to enable in-flight weight updates
recompute_kv_cache_after_weight_updates: false # Set to true to recompute the KV cache after in-flight weight updates

loss_fn:
reference_policy_kl_penalty: 0.01
84 changes: 70 additions & 14 deletions nemo_rl/algorithms/async_utils.py
@@ -44,7 +44,7 @@ def __init__(self, max_size: int):
if max_size <= 0:
raise ValueError(f"max_size must be positive, got {max_size}")
self.max_size = max_size
self.trajectories = []
self.trajectories = [] # List[dict[str, Any]]
# If trajectory_version is 1 and target_weight_version is 4, weight version 1 was used to generate the trajectory, and the trajectory will be used for training when the trainer reaches weight version 4.
self.trajectory_versions = []  # the weight version used to generate each trajectory
self.target_weight_versions = []  # the trainer weight version at which each trajectory will be used
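The versioning scheme described in the comments above can be sketched as a minimal standalone buffer (hypothetical and heavily simplified; not the real `ReplayBuffer` API):

```python
class VersionedBuffer:
    """Minimal sketch of version-tagged trajectory storage.

    Each trajectory remembers the weight version that generated it, so the
    trainer can drop samples older than max_age_steps weight versions.
    """

    def __init__(self, max_age_steps: int):
        self.max_age_steps = max_age_steps
        # Each entry pairs the generating weight version with the trajectory.
        self.items: list[tuple[int, dict]] = []

    def add(self, trajectory: dict, generation_version: int) -> None:
        self.items.append((generation_version, trajectory))

    def sample(self, current_version: int) -> list[dict]:
        # Keep only trajectories whose age (in weight versions) is in bounds.
        return [
            t for v, t in self.items
            if current_version - v <= self.max_age_steps
        ]
```

For example, with `max_age_steps=1`, a trajectory generated at weight version 1 is still usable at version 2 but excluded at version 3.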
@@ -278,8 +278,12 @@ def __init__(
self._inflight_threads: set[_threading.Thread] = set()
self._threads_lock: _threading.Lock = _threading.Lock()

# Limit in-flight generator requests to num_prompts_per_step
max_inflight = int(self.master_config["grpo"]["num_prompts_per_step"]) or 1
# Limit in-flight generator requests to num_prompts_per_step * max_trajectory_age_steps
# This value limits the parallelism of the generation requests.
max_inflight = (
int(self.master_config["grpo"]["num_prompts_per_step"])
* int(self.master_config["grpo"]["async_grpo"]["max_trajectory_age_steps"])
) or 1
self._inflight_sema = _threading.Semaphore(max_inflight)

# Simple lock to prevent race conditions when checking/spawning workers
@@ -475,6 +479,9 @@ def _process_batch(self, batch: BatchedDataDict[DatumSpec]) -> None:
print(
f"⏸️ Waiting for refit to complete before starting new generation ({active_threads} threads still active)"
)
print(
" Note: With vLLM V1 async engine, active threads can complete during weight update"
)
self._refit_pause_cleared.wait()

# After refit finishes if weight version has updated, reflect that in the new trajectories
@@ -520,28 +527,77 @@ def resume(self) -> None:
print("Trajectory collection resumed")

def prepare_for_refit(self) -> None:
"""Pause new generation starts and wait for pending generations to complete before refit."""
"""Pause new generation starts and optionally wait for pending generations.

For vLLM V1 async engine, leverages in-flight weight updates via collective_rpc,
allowing ongoing generations to continue with their current KV caches while
weights are updated. This significantly improves async performance.

For non-async engines, waits for all pending generations to complete before refit.
"""
start_time = time.time()
print("🔄 Preparing for refit: pausing new generations...")

# Pause new generation starts
self._refit_pause_cleared.clear()
print("⏸️ New generation starts paused")

# Wait for all pending generations to complete
# Note that is suboptimal for async performance and will be fixed in a follow-up PR where two more options will be added:
# 1. Pause the generations at their current decoding step, update the weights and continue with decoding.
# 2. Stop the current generations, store in a buffer and resume them in next iteration with new weights.
self.wait_for_pending_generations()
# Check if we're using vLLM async engine
vllm_cfg = (
self.master_config.get("policy", {})
.get("generation", {})
.get("vllm_cfg", {})
)
is_async_engine = vllm_cfg.get("async_engine", False)
in_flight_weight_updates = (
self.master_config.get("grpo", {})
.get("async_grpo", {})
.get("in_flight_weight_updates", False)
)

if is_async_engine and in_flight_weight_updates:
# vLLM V1 async engine supports in-flight weight updates
# Ongoing generations will continue with their current KV caches
# New generations (after weight update) will use the updated weights
print(
"🚀 Using vLLM V1 in-flight weight update - skipping wait for pending generations"
)
print(
f" {len(self._inflight_threads)} ongoing generations will complete with current weights"
)
else:
# For non-async engines, wait for all pending generations to complete
print(
"⏸️ Non-async engine: waiting for all pending generations to complete..."
)
self.wait_for_pending_generations()

elapsed = time.time() - start_time
print(
f"✅ All pending generations completed, ready for refit (took {elapsed:.2f}s)"
)
print(f"✅ Ready for refit (took {elapsed:.2f}s)")

def resume_after_refit(self) -> None:
"""Resume new generation starts after refit is complete."""
print("🔄 Resuming generation starts after refit")

# Invalidate and recompute vLLM caches after in-flight weight updates if
# recompute_kv_cache_after_weight_updates is True (AREAL-style implementation).
# Otherwise, keep using the stale KV caches (Magistral-style implementation).
async_cfg = self.master_config.get("grpo", {}).get("async_grpo", {})
if async_cfg.get("in_flight_weight_updates", False) and async_cfg.get(
"recompute_kv_cache_after_weight_updates", False
):
try:
print("🔄 Invalidating vLLM prefix/KV caches after weight update")
invalidated = self.policy_generation.invalidate_kv_cache()
if invalidated:
print("✅ Invalidated vLLM prefix/KV caches after weight update")
else:
print(
"⚠️ vLLM cache invalidation reported partial/unsuccessful on some workers"
)
except Exception as e:
print(f"⚠️ Failed to invalidate vLLM caches: {e}")

self._refit_pause_cleared.set()

def wait_for_pending_generations(self) -> None:
@@ -636,8 +692,8 @@ def _run_prompt_group_worker(
)
break
elif status == "full":
# Exponential backoff up to 1 second
time.sleep(min(backoff_delay, 1.0))
# Exponential backoff up to 0.5 seconds
time.sleep(min(backoff_delay, 0.5))
backoff_delay *= 1.5
else:
# Unexpected status, wait briefly
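The capped exponential backoff used in the retry loop above can be sketched in isolation (the initial delay is an illustrative assumption; only the 1.5x growth factor and the 0.5 s cap come from the code):

```python
def backoff_delays(initial: float = 0.05, factor: float = 1.5,
                   cap: float = 0.5, attempts: int = 10) -> list[float]:
    """Return the sleep durations a capped exponential backoff would use.

    `initial` is an illustrative assumption; `factor=1.5` and `cap=0.5`
    mirror the retry loop above.
    """
    delays = []
    delay = initial
    for _ in range(attempts):
        delays.append(min(delay, cap))  # never sleep longer than the cap
        delay *= factor
    return delays
```

Capping the delay keeps retries against a full replay buffer responsive while still reducing contention as the buffer stays full.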
18 changes: 17 additions & 1 deletion nemo_rl/algorithms/grpo.py
@@ -102,6 +102,11 @@ class AsyncGRPOConfig(TypedDict):
# async replay buffer. Trajectories older than this are excluded during
# sampling; buffer sizing also scales with this value.
max_trajectory_age_steps: int
# Synchronize weights as soon as training finishes, without waiting
# for pending generations to complete.
in_flight_weight_updates: NotRequired[bool]
# Recomputes the KV cache after the in-flight weight updates.
recompute_kv_cache_after_weight_updates: NotRequired[bool]


class GRPOConfig(TypedDict):
@@ -1512,6 +1517,16 @@ def async_grpo_train(
assert master_config["loss_fn"]["use_importance_sampling_correction"] is True, (
"Importance sampling correction must be enabled for async GRPO for good convergence due to off-policy samples!"
)

if master_config["grpo"]["async_grpo"]["max_trajectory_age_steps"] > 1:
if not master_config["grpo"]["async_grpo"].get(
"in_flight_weight_updates", False
):
print(
"⚠️ WARNING: In-flight weight updates should be enabled for async GRPO with max_trajectory_age_steps > 1. "
"Without in-flight weight updates, increasing max_trajectory_age_steps will not give any performance benefit."
)

# Import async utilities only when needed
from nemo_rl.algorithms.async_utils import AsyncTrajectoryCollector, ReplayBuffer

@@ -1725,7 +1740,7 @@ def async_grpo_train(
with timer.time("total_step_time"):
# Sample trajectories from replay buffer
print("📦 Sampling from replay buffer...")
with timer.time("buffer_sampling"):
with timer.time("exposed_generation"):
buffer_size_current = ray.get(replay_buffer.size.remote())
print(
f"📊 Step coordination: training_step={step}, max_age={max_trajectory_age_steps}, buffer_size={buffer_size_current}"
@@ -2078,6 +2093,7 @@ def async_grpo_train(
tokenizer_path=os.path.join(
checkpoint_path, "policy", "tokenizer"
),
checkpointing_cfg=master_config["checkpointing"],
)
# Get dataloader state from trajectory collector
actual_dataloader_state = ray.get(
5 changes: 5 additions & 0 deletions nemo_rl/models/generation/interfaces.py
@@ -247,3 +247,8 @@ def update_weights_via_ipc_zmq(self) -> list[ray.ObjectRef]:
def update_weights_from_collective(self) -> list[ray.ObjectRef]:
"""Update the model weights from collective communication."""
raise NotImplementedError

# Optional hook; backends may override to invalidate any reusable caches
# (e.g., vLLM prefix/KV caches) after weight updates.
def invalidate_kv_cache(self) -> bool:
return False
22 changes: 22 additions & 0 deletions nemo_rl/models/generation/vllm/vllm_generation.py
@@ -827,3 +827,25 @@ def __del__(self) -> None:
user calls shutdown().
"""
self.shutdown()

def invalidate_kv_cache(self) -> bool:
"""Invalidate reusable caches in vLLM (e.g., prefix/KV cache) after weight updates.

For async_engine, calls reset_prefix_cache_async on workers. For sync, calls reset_prefix_cache.
Returns True if all workers report success.
"""
try:
method_name = (
"reset_prefix_cache_async"
if self.cfg["vllm_cfg"]["async_engine"]
else "reset_prefix_cache"
)
futures = self.worker_group.run_all_workers_single_data(
method_name,
run_rank_0_only_axes=["tensor_parallel", "pipeline_parallel"],
)
results = ray.get(futures)
return all(result for result in results if result is not None)
except Exception as e:
print(f"Error invalidating vLLM caches: {e}")
return False
4 changes: 4 additions & 0 deletions nemo_rl/models/policy/lm_policy.py
@@ -664,6 +664,10 @@ def finish_generation(self, *args: Any, **kwargs: Any) -> bool:
# We don't need to do anything here
return True

def invalidate_kv_cache(self) -> bool:
# We don't need to do anything here
return True


def finish_training(self, *args: Any, **kwargs: Any) -> None:
# Placeholder implementation
pass
14 changes: 13 additions & 1 deletion nemo_rl/utils/logger.py
@@ -132,7 +132,19 @@ def log_metrics(
for name, value in metrics.items():
if prefix:
name = f"{prefix}/{name}"
self.writer.add_scalar(name, value, step)

# Skip non-scalar values that TensorBoard can't handle
if isinstance(value, (dict, list)):
print(
f"Warning: Skipping non-scalar metric '{name}' for TensorBoard logging (type: {type(value).__name__})"
)
continue

try:
self.writer.add_scalar(name, value, step)
except Exception as e:
print(f"Warning: Failed to log metric '{name}' to TensorBoard: {e}")
continue

def log_hyperparams(self, params: Mapping[str, Any]) -> None:
"""Log hyperparameters to Tensorboard.
4 changes: 4 additions & 0 deletions tests/unit/algorithms/test_grpo.py
@@ -933,6 +933,10 @@ def val_iter(self):
"reward_scaling": {"enabled": False},
"reward_shaping": {"enabled": False},
"use_dynamic_sampling": False,
"async_grpo": {
"enabled": False,
"max_trajectory_age_steps": 1,
},
},
"policy": {
"train_global_batch_size": 1,