diff --git a/nemo_reinforcer/algorithms/grpo.py b/nemo_reinforcer/algorithms/grpo.py
index 0eda853375..06edc440d4 100644
--- a/nemo_reinforcer/algorithms/grpo.py
+++ b/nemo_reinforcer/algorithms/grpo.py
@@ -277,9 +277,10 @@ def refit_policy_generation(
     """Refit the policy generation interface with the latest policy weights."""
     policy.offload_before_refit()
     ipc_handles = policy.get_weights_ipc_handles()
-    policy_generation.prepare_for_generation()
+    policy_generation.prepare_for_generation(tags=["weights"])
     policy_generation.update_weights(ipc_handles)
     policy.offload_after_refit()
+    policy_generation.prepare_for_generation(tags=["kv_cache"])
 
 
 def generate_responses(
diff --git a/nemo_reinforcer/models/generation/vllm.py b/nemo_reinforcer/models/generation/vllm.py
index c9676a2f2b..343fe2d5a5 100644
--- a/nemo_reinforcer/models/generation/vllm.py
+++ b/nemo_reinforcer/models/generation/vllm.py
@@ -410,8 +410,12 @@ def sleep(self):
         gc.collect()
         torch.cuda.empty_cache()
 
-    def wake_up(self):
-        self.llm.wake_up()
+    def wake_up(self, **kwargs):
+        # tags like ["weights", "kv_cache"]
+        if "tags" in kwargs:
+            self.llm.wake_up(tags=kwargs["tags"])
+        else:
+            self.llm.wake_up()
 
 
 class VllmGeneration(GenerationInterface):
@@ -580,7 +584,7 @@ def prepare_for_generation(self, *args, **kwargs):
        try:
             # Use run_all_workers_single_data for methods that don't need data
             futures = self.worker_group.run_all_workers_single_data(
-                "wake_up", respect_tied_workers=True
+                "wake_up", respect_tied_workers=True, **kwargs
             )
             # Wait for all futures to complete
             results = ray.get(futures)