From 6f041f56b4940fdd9f816b516751444193f91084 Mon Sep 17 00:00:00 2001 From: Parth Chadha Date: Sun, 2 Nov 2025 22:25:58 -0800 Subject: [PATCH] feat: add capability to update weights inflight during generation (#1381) Signed-off-by: Parth Chadha Signed-off-by: Youngeun Kwon Signed-off-by: Terry Kong Co-authored-by: Youngeun Kwon Co-authored-by: Terry Kong Co-authored-by: Terry Kong Signed-off-by: NeMo Bot --- docs/guides/async-grpo.md | 9 ++ examples/configs/grpo_math_1B.yaml | 2 + nemo_rl/algorithms/async_utils.py | 84 +++++++++++++++---- nemo_rl/algorithms/grpo.py | 18 +++- nemo_rl/models/generation/interfaces.py | 5 ++ .../models/generation/vllm/vllm_generation.py | 22 +++++ nemo_rl/models/policy/lm_policy.py | 4 + nemo_rl/utils/logger.py | 14 +++- tests/unit/algorithms/test_grpo.py | 4 + 9 files changed, 146 insertions(+), 16 deletions(-) diff --git a/docs/guides/async-grpo.md b/docs/guides/async-grpo.md index 50f84b4e66..0beac8204a 100644 --- a/docs/guides/async-grpo.md +++ b/docs/guides/async-grpo.md @@ -41,6 +41,8 @@ grpo: async_grpo: enabled: true max_trajectory_age_steps: 1 # Maximum age, in training steps, for trajectories + in_flight_weight_updates: false # Enable for faster weight synchronization + recompute_kv_cache_after_weight_updates: false # Invalidates kv cache after in-flight-weight-updates ``` ### Complete Example Config @@ -65,6 +67,8 @@ grpo: async_grpo: enabled: true max_trajectory_age_steps: 1 + in_flight_weight_updates: false # Enable for faster weight synchronization + recompute_kv_cache_after_weight_updates: false # Invalidates kv cache after in-flight-weight-updates cluster: num_nodes: 2 @@ -158,6 +162,11 @@ sequenceDiagram 3. **Resource Allocation**: Ensure sufficient GPU memory for both the training and generation clusters +4. **In-Flight Weight Updates**: Enable `in_flight_weight_updates: true` when using `async_engine: true` for updating the weights of vLLM engine during generation. 
This prevents stalling the training pipeline until the longest generation finishes and provides significant performance benefits. + +5. **Recompute KV Cache After Weight Updates**: When using in-flight weight updates, the user can choose whether to recompute +KV caches after a weight update via the `recompute_kv_cache_after_weight_updates` configuration option. + + ## Why Importance Sampling Correction Is Required for Async ### The GRPO Objective diff --git a/examples/configs/grpo_math_1B.yaml b/examples/configs/grpo_math_1B.yaml index f4a636080c..1f50da778f 100644 --- a/examples/configs/grpo_math_1B.yaml +++ b/examples/configs/grpo_math_1B.yaml @@ -32,6 +32,8 @@ grpo: enabled: false # Set to true to enable async training mode # Max age (in training steps) for trajectories used in training max_trajectory_age_steps: 1 + in_flight_weight_updates: false # Set to true to enable in-flight weight updates + recompute_kv_cache_after_weight_updates: false # Set to true to recompute kv cache after in-flight-weight-updates loss_fn: reference_policy_kl_penalty: 0.01 diff --git a/nemo_rl/algorithms/async_utils.py b/nemo_rl/algorithms/async_utils.py index 55703a27b1..c1ce9ab762 100644 --- a/nemo_rl/algorithms/async_utils.py +++ b/nemo_rl/algorithms/async_utils.py @@ -44,7 +44,7 @@ def __init__(self, max_size: int): if max_size <= 0: raise ValueError(f"max_size must be positive, got {max_size}") self.max_size = max_size - self.trajectories = [] + self.trajectories = [] # List[dict[str, Any]] # If trajectory_version is 1 and target_weight_version is 4 it means that weight version 1 was used for generating a trajectory and this trajectory will be used for training when weight version is 4. self.trajectory_versions = [] # it is the weight-version used for generation of a trajectory self.target_weight_versions = [] # it is the weight-version of the trainer where this trajectory will be used. 
@@ -278,8 +278,12 @@ def __init__( self._inflight_threads: set[_threading.Thread] = set() self._threads_lock: _threading.Lock = _threading.Lock() - # Limit in-flight generator requests to num_prompts_per_step - max_inflight = int(self.master_config["grpo"]["num_prompts_per_step"]) or 1 + # Limit in-flight generator requests to num_prompts_per_step * max_trajectory_age_steps + # This value limits the parallelism of the generation requests. + max_inflight = ( + int(self.master_config["grpo"]["num_prompts_per_step"]) + * int(self.master_config["grpo"]["async_grpo"]["max_trajectory_age_steps"]) + ) or 1 self._inflight_sema = _threading.Semaphore(max_inflight) # Simple lock to prevent race conditions when checking/spawning workers @@ -475,6 +479,9 @@ def _process_batch(self, batch: BatchedDataDict[DatumSpec]) -> None: print( f"⏸️ Waiting for refit to complete before starting new generation ({active_threads} threads still active)" ) + print( + " Note: With vLLM V1 async engine, active threads can complete during weight update" + ) self._refit_pause_cleared.wait() # After refit finishes if weight version has updated, reflect that in the new trajectories @@ -520,7 +527,14 @@ def resume(self) -> None: print("Trajectory collection resumed") def prepare_for_refit(self) -> None: - """Pause new generation starts and wait for pending generations to complete before refit.""" + """Pause new generation starts and optionally wait for pending generations. + + For vLLM V1 async engine, leverages in-flight weight updates via collective_rpc, + allowing ongoing generations to continue with their current KV caches while + weights are updated. This significantly improves async performance. + + For non-async engines, waits for all pending generations to complete before refit. 
+ """ start_time = time.time() print("πŸ”„ Preparing for refit: pausing new generations...") @@ -528,20 +542,62 @@ def prepare_for_refit(self) -> None: self._refit_pause_cleared.clear() print("⏸️ New generation starts paused") - # Wait for all pending generations to complete - # Note that is suboptimal for async performance and will be fixed in a follow-up PR where two more options will be added: - # 1. Pause the generations at their current decoding step, update the weights and continue with decoding. - # 2. Stop the current generations, store in a buffer and resume them in next iteration with new weights. - self.wait_for_pending_generations() + # Check if we're using vLLM async engine + vllm_cfg = ( + self.master_config.get("policy", {}) + .get("generation", {}) + .get("vllm_cfg", {}) + ) + is_async_engine = vllm_cfg.get("async_engine", False) + in_flight_weight_updates = ( + self.master_config.get("grpo", {}) + .get("async_grpo", {}) + .get("in_flight_weight_updates", False) + ) + + if is_async_engine and in_flight_weight_updates: + # vLLM V1 async engine supports in-flight weight updates + # Ongoing generations will continue with their current KV caches + # New generations (after weight update) will use the updated weights + print( + "πŸš€ Using vLLM V1 in-flight weight update - skipping wait for pending generations" + ) + print( + f" {len(self._inflight_threads)} ongoing generations will complete with current weights" + ) + else: + # For non-async engines, wait for all pending generations to complete + print( + "⏸️ Non-async engine: waiting for all pending generations to complete..." 
+ ) + self.wait_for_pending_generations() elapsed = time.time() - start_time - print( - f"βœ… All pending generations completed, ready for refit (took {elapsed:.2f}s)" - ) + print(f"βœ… Ready for refit (took {elapsed:.2f}s)") def resume_after_refit(self) -> None: """Resume new generation starts after refit is complete.""" print("πŸ”„ Resuming generation starts after refit") + + # Invalidate&recompute vLLM caches after the in-flight weight updates if + # recompute_kv_cache_after_weight_updates is True (AREAL-style implementation). + # Otherwise, keep using the stale KV caches (Magistral-style implementation). + async_cfg = self.master_config.get("grpo", {}).get("async_grpo", {}) + if async_cfg.get("in_flight_weight_updates", False) and async_cfg.get( + "recompute_kv_cache_after_weight_updates", False + ): + try: + print("πŸ”„ Invalidating vLLM prefix/KV caches after weight update") + invalidated = self.policy_generation.invalidate_kv_cache() + if invalidated: + print("βœ… Invalidated vLLM prefix/KV caches after weight update") + else: + print( + "⚠️ vLLM cache invalidation reported partial/unsuccessful on some workers" + ) + except Exception as e: + print(f"⚠️ Failed to invalidate vLLM caches: {e}") + self._refit_pause_cleared.set() def wait_for_pending_generations(self) -> None: @@ -636,8 +692,8 @@ def _run_prompt_group_worker( ) break elif status == "full": - # Exponential backoff up to 1 second - time.sleep(min(backoff_delay, 1.0)) + # Exponential backoff up to 0.5 second + time.sleep(min(backoff_delay, 0.5)) backoff_delay *= 1.5 else: # Unexpected status, wait briefly diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index 578daafda3..9d9de05ce4 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -102,6 +102,11 @@ class AsyncGRPOConfig(TypedDict): # async replay buffer. Trajectories older than this are excluded during # sampling; buffer sizing also scales with this value. 
max_trajectory_age_steps: int + # Does the weight synchronization as soon as the training is done + # without waiting for the pending generations to finish. + in_flight_weight_updates: NotRequired[bool] + # Recomputes the KV cache after the in-flight weight updates. + recompute_kv_cache_after_weight_updates: NotRequired[bool] class GRPOConfig(TypedDict): @@ -1512,6 +1517,16 @@ def async_grpo_train( assert master_config["loss_fn"]["use_importance_sampling_correction"] is True, ( "Importance sampling correction must be enabled for async GRPO for good convergence due to off-policy samples!" ) + + if master_config["grpo"]["async_grpo"]["max_trajectory_age_steps"] > 1: + if not master_config["grpo"]["async_grpo"].get( + "in_flight_weight_updates", False + ): + print( + "⚠️ WARNING: In-flight weight updates must be enabled for async GRPO with max_trajectory_age_steps > 1. " + "Without in-flight weight updates, having more max_trajectory_age_steps will not give any performance benefit." + ) + # Import async utilities only when needed from nemo_rl.algorithms.async_utils import AsyncTrajectoryCollector, ReplayBuffer @@ -1725,7 +1740,7 @@ def async_grpo_train( with timer.time("total_step_time"): # Sample trajectories from replay buffer print("πŸ“¦ Sampling from replay buffer...") - with timer.time("buffer_sampling"): + with timer.time("exposed_generation"): buffer_size_current = ray.get(replay_buffer.size.remote()) print( f"πŸ“Š Step coordination: training_step={step}, max_age={max_trajectory_age_steps}, buffer_size={buffer_size_current}" @@ -2078,6 +2093,7 @@ def async_grpo_train( tokenizer_path=os.path.join( checkpoint_path, "policy", "tokenizer" ), + checkpointing_cfg=master_config["checkpointing"], ) # Get dataloader state from trajectory collector actual_dataloader_state = ray.get( diff --git a/nemo_rl/models/generation/interfaces.py b/nemo_rl/models/generation/interfaces.py index 7b3ed190f5..f7f58b383f 100644 --- a/nemo_rl/models/generation/interfaces.py +++ 
b/nemo_rl/models/generation/interfaces.py @@ -247,3 +247,8 @@ def update_weights_via_ipc_zmq(self) -> list[ray.ObjectRef]: def update_weights_from_collective(self) -> list[ray.ObjectRef]: """Update the model weights from collective communication.""" raise NotImplementedError + + # Optional hook; backends may override to invalidate any reusable caches + # (e.g., vLLM prefix/KV caches) after weight updates. + def invalidate_kv_cache(self) -> bool: + return False diff --git a/nemo_rl/models/generation/vllm/vllm_generation.py b/nemo_rl/models/generation/vllm/vllm_generation.py index 4efb492cd3..5dcc7eaf2e 100644 --- a/nemo_rl/models/generation/vllm/vllm_generation.py +++ b/nemo_rl/models/generation/vllm/vllm_generation.py @@ -827,3 +827,25 @@ def __del__(self) -> None: user calls shutdown(). """ self.shutdown() + + def invalidate_kv_cache(self) -> bool: + """Invalidate reusable caches in vLLM (e.g., prefix/KV cache) after weight updates. + + For async_engine, calls reset_prefix_cache_async on workers. For sync, calls reset_prefix_cache. + Returns True if all workers report success. 
+ """ + try: + method_name = ( + "reset_prefix_cache_async" + if self.cfg["vllm_cfg"]["async_engine"] + else "reset_prefix_cache" + ) + futures = self.worker_group.run_all_workers_single_data( + method_name, + run_rank_0_only_axes=["tensor_parallel", "pipeline_parallel"], + ) + results = ray.get(futures) + return all(result for result in results if result is not None) + except Exception as e: + print(f"Error invalidating vLLM caches: {e}") + return False diff --git a/nemo_rl/models/policy/lm_policy.py b/nemo_rl/models/policy/lm_policy.py index b69555ae15..7cea527aac 100644 --- a/nemo_rl/models/policy/lm_policy.py +++ b/nemo_rl/models/policy/lm_policy.py @@ -664,6 +664,10 @@ def finish_generation(self, *args: Any, **kwargs: Any) -> bool: # We don't need to do anything here return True + def invalidate_kv_cache(self, *args: Any, **kwargs: Any) -> bool: + # We don't need to do anything here + return True + def finish_training(self, *args: Any, **kwargs: Any) -> None: # Placeholder implementation pass diff --git a/nemo_rl/utils/logger.py b/nemo_rl/utils/logger.py index fa76f1295d..acf04d8a36 100644 --- a/nemo_rl/utils/logger.py +++ b/nemo_rl/utils/logger.py @@ -132,7 +132,19 @@ def log_metrics( for name, value in metrics.items(): if prefix: name = f"{prefix}/{name}" - self.writer.add_scalar(name, value, step) + + # Skip non-scalar values that TensorBoard can't handle + if isinstance(value, (dict, list)): + print( + f"Warning: Skipping non-scalar metric '{name}' for TensorBoard logging (type: {type(value).__name__})" + ) + continue + + try: + self.writer.add_scalar(name, value, step) + except Exception as e: + print(f"Warning: Failed to log metric '{name}' to TensorBoard: {e}") + continue def log_hyperparams(self, params: Mapping[str, Any]) -> None: """Log hyperparameters to Tensorboard. 
diff --git a/tests/unit/algorithms/test_grpo.py b/tests/unit/algorithms/test_grpo.py index 34ed2ef88f..b3e1a0c9e9 100644 --- a/tests/unit/algorithms/test_grpo.py +++ b/tests/unit/algorithms/test_grpo.py @@ -933,6 +933,10 @@ def val_iter(self): "reward_scaling": {"enabled": False}, "reward_shaping": {"enabled": False}, "use_dynamic_sampling": False, + "async_grpo": { + "enabled": False, + "max_trajectory_age_steps": 1, + }, }, "policy": { "train_global_batch_size": 1,