diff --git a/nemo_rl/algorithms/async_utils.py b/nemo_rl/algorithms/async_utils.py index 0e2d80ae01..67a9b5b448 100644 --- a/nemo_rl/algorithms/async_utils.py +++ b/nemo_rl/algorithms/async_utils.py @@ -469,6 +469,9 @@ def _process_batch(self, batch: BatchedDataDict[DatumSpec]) -> None: f"⏸️ Waiting for refit to complete before starting new generation ({active_threads} threads still active)" ) self._refit_pause_cleared.wait() + + # After refit finishes, if the weight version was updated, reflect that in the new trajectories + generation_weight_version = self.current_weight_version single_prompt_batch = batch.slice(prompt_idx, prompt_idx + 1) repeated_batch = single_prompt_batch.repeat_interleave(num_generations) diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index 45633e7b4d..50f02ec362 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -1487,11 +1487,12 @@ def async_grpo_train( ) POLICY_GENERATION_STALE = False - trajectory_collector.resume_after_refit.remote() - - # Notify collector about the new weight version (post-update) - weight_version += 1 - trajectory_collector.set_weight_version.remote(weight_version) + # Update the weight version before resuming trajectory collection so that new trajectories are tagged with the correct weight version + weight_version += 1 + trajectory_collector.set_weight_version.remote(weight_version) + trajectory_collector.resume_after_refit.remote() + + # Validation val_metrics, validation_timings = None, None