From 1c389de595560af208afad2d88419c644406a273 Mon Sep 17 00:00:00 2001 From: ruit Date: Tue, 13 Jan 2026 03:58:32 -0800 Subject: [PATCH 1/4] support lora async Signed-off-by: ruit --- nemo_rl/algorithms/grpo.py | 29 ++++++++++---- .../models/generation/vllm/vllm_backend.py | 14 +++++++ .../generation/vllm/vllm_worker_async.py | 40 ++++++++++++++++++- 3 files changed, 74 insertions(+), 9 deletions(-) diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index 78ad740858..eb582400f1 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -466,10 +466,6 @@ def setup( # Override the vLLM lora config with the DTensor lora config generation_config["vllm_cfg"]["lora_cfg"] = lora_cfg - assert not _should_use_async_rollouts(master_config), ( - "Async rollouts are not supported with LoRA in DTensor backend." - ) - # Define initialization functions that will be used in all paths def init_policy(): """Initialize policy training workers.""" @@ -1930,6 +1926,8 @@ def async_grpo_train( policy_generation = policy NEED_REFIT = False POLICY_GENERATION_STALE = True + REFIT_BASE_MODEL_WEIGHTS = True + REFIT_LORA_WEIGHTS = policy.lora_enabled assert policy_generation is not None # Training state @@ -2041,9 +2039,16 @@ def async_grpo_train( if NEED_REFIT and POLICY_GENERATION_STALE: print("🔄 Refitting policy generation with actual model weights...") try: - refit_policy_generation(policy, policy_generation, colocated_inference) + refit_policy_generation( + policy, + policy_generation, + colocated_inference, + refit_base_model_weights=REFIT_BASE_MODEL_WEIGHTS, + refit_lora_weights=REFIT_LORA_WEIGHTS, + ) print("✅ Policy generation refit completed successfully") POLICY_GENERATION_STALE = False + REFIT_BASE_MODEL_WEIGHTS = False if REFIT_LORA_WEIGHTS else True except Exception as e: print(f"❌ Policy generation refit failed: {e}") import traceback @@ -2361,9 +2366,14 @@ def async_grpo_train( print("🔄 Performing policy generation refit...") with timer.time("weight_sync"): refit_policy_generation( - policy, policy_generation, colocated_inference + policy, + policy_generation, + colocated_inference, + refit_base_model_weights=REFIT_BASE_MODEL_WEIGHTS, + refit_lora_weights=REFIT_LORA_WEIGHTS, ) POLICY_GENERATION_STALE = False + REFIT_BASE_MODEL_WEIGHTS = False if REFIT_LORA_WEIGHTS else True # Update weight version before resuming trajectory collection so that all trajectories are updated with the new correct weight version weight_version += 1 @@ -2386,9 +2396,14 @@ def async_grpo_train( if NEED_REFIT and POLICY_GENERATION_STALE: refit_policy_generation( - policy, policy_generation, colocated_inference + policy, + policy_generation, + colocated_inference, + refit_base_model_weights=REFIT_BASE_MODEL_WEIGHTS, + refit_lora_weights=REFIT_LORA_WEIGHTS, ) POLICY_GENERATION_STALE = False + REFIT_BASE_MODEL_WEIGHTS = False if REFIT_LORA_WEIGHTS else True else: policy_generation.prepare_for_generation() val_metrics, validation_timings = validate( diff --git a/nemo_rl/models/generation/vllm/vllm_backend.py b/nemo_rl/models/generation/vllm/vllm_backend.py index a0c855b2f7..fa4a577889 100644 --- a/nemo_rl/models/generation/vllm/vllm_backend.py +++ b/nemo_rl/models/generation/vllm/vllm_backend.py @@ -129,6 +129,20 @@ def _maybe_process_fp8_kv_cache(self) -> None: target_device, ) + def apply_lora_patches(self) -> None: + """Apply LoRA patches inside the vLLM worker process. Used for async worker.""" + try: + from nemo_rl.models.generation.lora import apply_lora_patches + + apply_lora_patches() + + except Exception as e: + print(f"Failed to apply LoRA patches in worker extension: {e}") + import traceback as _tb + + print(_tb.format_exc()) + raise e + def _apply_weight_name_mapping( self, weights: list[tuple[str, torch.Tensor]] ) -> list[tuple[str, torch.Tensor]]: diff --git a/nemo_rl/models/generation/vllm/vllm_worker_async.py b/nemo_rl/models/generation/vllm/vllm_worker_async.py index 83fef65a7d..63f08bb56d 100644 --- a/nemo_rl/models/generation/vllm/vllm_worker_async.py +++ b/nemo_rl/models/generation/vllm/vllm_worker_async.py @@ -276,6 +276,17 @@ def clear_vllm_logger_metrics(self) -> None: async def post_init_async(self): self.vllm_device_ids = await self.report_device_id_async() + # Ensure LoRA patches are applied inside engine worker processes (async path) + if getattr(self, "lora_enabled", False) and self.llm is not None: + try: + await self.llm.collective_rpc("apply_lora_patches", args=tuple()) + print( + "Successfully applied lora patches in engine workers (async worker)" + ) + except Exception as e: + print( + f"[WARNING] Failed to apply lora patches in engine workers (async worker): {e}" + ) async def report_dp_openai_server_base_url(self) -> Optional[str]: return self.base_url @@ -746,10 +757,21 @@ async def process_single_sample(sample_idx): request_id = str(uuid.uuid4()) + lora_req = None + if self.lora_enabled: + from vllm.lora.request import LoRARequest + from nemo_rl.models.generation.lora import get_vllm_lora_metadata + + lora_metadata = get_vllm_lora_metadata() + lora_req = LoRARequest( + **lora_metadata, + ) + # Generate using vLLM async engine vllm_request_generator = self.llm.generate( prompt=prompt, sampling_params=sampling_params_for_request, + lora_request=lora_req, request_id=request_id, ) @@ -919,10 +941,21 @@ async def process_single_prompt(prompt_idx): request_id = str(uuid.uuid4()) + lora_req = None + if self.lora_enabled: + from vllm.lora.request import LoRARequest + from nemo_rl.models.generation.lora import get_vllm_lora_metadata + + lora_metadata = get_vllm_lora_metadata() + lora_req = LoRARequest( + **lora_metadata, + ) + # Generate using vLLM async engine vllm_request_generator = self.llm.generate( prompt=prompt, sampling_params=sampling_params, + lora_request=lora_req, request_id=request_id, ) @@ -1027,7 +1060,9 @@ async def update_weights_via_ipc_zmq_async( traceback.print_exc() return False - async def update_weights_from_collective_async(self) -> bool: + async def update_weights_from_collective_async( + self, refit_base_model_weights: bool = True, refit_lora_weights: bool = False + ) -> bool: """Async version of update_weights_from_collective.""" try: assert self.llm is not None, ( @@ -1040,7 +1075,8 @@ async def update_weights_from_collective_async(self) -> bool: ) result_or_coro = await self.llm.collective_rpc( - "update_weights_from_collective", args=tuple() + "update_weights_from_collective", + args=(self.lora_cfg, refit_base_model_weights, refit_lora_weights), ) if asyncio.iscoroutine(result_or_coro): From f0ad4e33310d52dc89056671ebd8809cf3781e69 Mon Sep 17 00:00:00 2001 From: ruit Date: Thu, 15 Jan 2026 00:54:50 -0800 Subject: [PATCH 2/4] refactor: update weight refitting parameters to use a unified 'refit_mode' across multiple interfaces Signed-off-by: ruit --- nemo_rl/algorithms/grpo.py | 9 --------- nemo_rl/models/generation/vllm/vllm_backend.py | 2 +- nemo_rl/models/generation/vllm/vllm_worker_async.py | 9 +++++---- 3 files changed, 6 insertions(+), 14 deletions(-) diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index eb582400f1..2968a9fea4 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -2043,12 +2043,9 @@ def async_grpo_train( policy, policy_generation, colocated_inference, - refit_base_model_weights=REFIT_BASE_MODEL_WEIGHTS, - refit_lora_weights=REFIT_LORA_WEIGHTS, ) print("✅ Policy generation refit completed successfully") POLICY_GENERATION_STALE = False - REFIT_BASE_MODEL_WEIGHTS = False if REFIT_LORA_WEIGHTS else True except Exception as e: print(f"❌ Policy generation refit failed: {e}") import traceback @@ -2369,11 +2366,8 @@ def async_grpo_train( policy, policy_generation, colocated_inference, - refit_base_model_weights=REFIT_BASE_MODEL_WEIGHTS, - refit_lora_weights=REFIT_LORA_WEIGHTS, ) POLICY_GENERATION_STALE = False - REFIT_BASE_MODEL_WEIGHTS = False if REFIT_LORA_WEIGHTS else True # Update weight version before resuming trajectory collection so that all trajectories are updated with the new correct weight version weight_version += 1 @@ -2399,11 +2393,8 @@ def async_grpo_train( policy, policy_generation, colocated_inference, - refit_base_model_weights=REFIT_BASE_MODEL_WEIGHTS, - refit_lora_weights=REFIT_LORA_WEIGHTS, ) POLICY_GENERATION_STALE = False - REFIT_BASE_MODEL_WEIGHTS = False if REFIT_LORA_WEIGHTS else True else: policy_generation.prepare_for_generation() val_metrics, validation_timings = validate( diff --git a/nemo_rl/models/generation/vllm/vllm_backend.py b/nemo_rl/models/generation/vllm/vllm_backend.py index fa4a577889..ca45e67b41 100644 --- a/nemo_rl/models/generation/vllm/vllm_backend.py +++ b/nemo_rl/models/generation/vllm/vllm_backend.py @@ -132,7 +132,7 @@ def _maybe_process_fp8_kv_cache(self) -> None: def apply_lora_patches(self) -> None: """Apply LoRA patches inside the vLLM worker process. Used for async worker.""" try: - from nemo_rl.models.generation.lora import apply_lora_patches + from nemo_rl.models.generation.vllm.lora import apply_lora_patches apply_lora_patches() diff --git a/nemo_rl/models/generation/vllm/vllm_worker_async.py b/nemo_rl/models/generation/vllm/vllm_worker_async.py index 63f08bb56d..86cc64e678 100644 --- a/nemo_rl/models/generation/vllm/vllm_worker_async.py +++ b/nemo_rl/models/generation/vllm/vllm_worker_async.py @@ -760,7 +760,7 @@ async def process_single_sample(sample_idx): lora_req = None if self.lora_enabled: from vllm.lora.request import LoRARequest - from nemo_rl.models.generation.lora import get_vllm_lora_metadata + from nemo_rl.models.generation.vllm.lora import get_vllm_lora_metadata lora_metadata = get_vllm_lora_metadata() lora_req = LoRARequest( @@ -944,7 +944,7 @@ async def process_single_prompt(prompt_idx): lora_req = None if self.lora_enabled: from vllm.lora.request import LoRARequest - from nemo_rl.models.generation.lora import get_vllm_lora_metadata + from nemo_rl.models.generation.vllm.lora import get_vllm_lora_metadata lora_metadata = get_vllm_lora_metadata() lora_req = LoRARequest( @@ -1061,7 +1061,8 @@ async def update_weights_via_ipc_zmq_async( return False async def update_weights_from_collective_async( - self, refit_base_model_weights: bool = True, refit_lora_weights: bool = False + self, + refit_mode: Optional[str] = "base_model", ) -> bool: """Async version of update_weights_from_collective.""" try: @@ -1076,7 +1077,7 @@ async def update_weights_from_collective_async( result_or_coro = await self.llm.collective_rpc( "update_weights_from_collective", - args=(self.lora_cfg, refit_base_model_weights, refit_lora_weights), + args=(self.lora_cfg, refit_mode), ) if asyncio.iscoroutine(result_or_coro): From cb7c69b630c7785b8922270e6a76590e6681daf9 Mon Sep 17 00:00:00 2001 From: ruit Date: Thu, 15 Jan 2026 01:36:13 -0800 Subject: [PATCH 3/4] add functional test Signed-off-by: ruit --- nemo_rl/algorithms/grpo.py | 2 - tests/functional/L1_Functional_Tests_GPU.sh | 1 + tests/functional/grpo_automodel_lora_async.sh | 52 +++++++++++++++++++ 3 files changed, 53 insertions(+), 2 deletions(-) create mode 100755 tests/functional/grpo_automodel_lora_async.sh diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index 2968a9fea4..df97ff1366 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -1926,8 +1926,6 @@ def async_grpo_train( policy_generation = policy NEED_REFIT = False POLICY_GENERATION_STALE = True - REFIT_BASE_MODEL_WEIGHTS = True - REFIT_LORA_WEIGHTS = policy.lora_enabled assert policy_generation is not None # Training state diff --git a/tests/functional/L1_Functional_Tests_GPU.sh b/tests/functional/L1_Functional_Tests_GPU.sh index d22c3ba5c8..4cb1e6b2fb 100755 --- a/tests/functional/L1_Functional_Tests_GPU.sh +++ b/tests/functional/L1_Functional_Tests_GPU.sh @@ -27,6 +27,7 @@ time uv run --no-sync bash ./tests/functional/sft.sh time uv run --no-sync bash ./tests/functional/sft_resume_diamond.sh time uv run --no-sync bash ./tests/functional/grpo.sh time uv run --no-sync bash ./tests/functional/grpo_async.sh +time uv run --no-sync bash ./tests/functional/grpo_automodel_lora_async.sh time uv run --no-sync bash ./tests/functional/grpo_automodel_lora_non_colocated.sh time uv run --no-sync bash ./tests/functional/grpo_automodel_lora.sh time uv run --no-sync bash ./tests/functional/grpo_megatron.sh diff --git a/tests/functional/grpo_automodel_lora_async.sh b/tests/functional/grpo_automodel_lora_async.sh new file mode 100755 index 0000000000..c3a99fedd0 --- /dev/null +++ b/tests/functional/grpo_automodel_lora_async.sh @@ -0,0 +1,52 @@ +#!/bin/bash + +# clean up checkpoint directory on exit +trap "rm -rf /tmp/lora_sft_checkpoints" EXIT + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) +PROJECT_ROOT=$(realpath $SCRIPT_DIR/../..) +# Mark the current repo as safe, since wandb fetches metadata about the repo +git config --global --add safe.directory $PROJECT_ROOT + +set -eou pipefail + +EXP_NAME=$(basename $0 .sh) +EXP_DIR=$SCRIPT_DIR/$EXP_NAME +LOG_DIR=$EXP_DIR/logs +JSON_METRICS=$EXP_DIR/metrics.json +RUN_LOG=$EXP_DIR/run.log +export PYTHONPATH=${PROJECT_ROOT}:${PYTHONPATH:-} + +rm -rf $EXP_DIR $LOG_DIR +mkdir -p $EXP_DIR $LOG_DIR + +cd $PROJECT_ROOT +NRL_FORCE_REBUILD_VENVS=true uv run coverage run -a --data-file=$PROJECT_ROOT/tests/.coverage --source=$PROJECT_ROOT/nemo_rl \ + $PROJECT_ROOT/examples/run_grpo_math.py\ + grpo.max_num_steps=3 \ + grpo.num_prompts_per_step=8 \ + grpo.num_generations_per_prompt=4 \ + policy.dtensor_cfg.lora_cfg.enabled=True \ + policy.dtensor_cfg.lora_cfg.dim=32 \ + policy.train_global_batch_size=32 \ + policy.train_micro_batch_size=1 \ + policy.generation.colocated.enabled=false \ + policy.generation.colocated.resources.gpus_per_node=1 \ + policy.generation.colocated.resources.num_nodes=1 \ + policy.generation.vllm_cfg.async_engine=true \ + grpo.async_grpo.enabled=true \ + loss_fn.use_importance_sampling_correction=true \ + cluster.gpus_per_node=2 \ + logger.tensorboard_enabled=true \ + logger.log_dir=$LOG_DIR \ + logger.wandb_enabled=false \ + logger.monitor_gpus=true \ + checkpointing.enabled=false \ + "$@" \ + 2>&1 | tee $RUN_LOG + +uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS + +uv run tests/check_metrics.py $JSON_METRICS \ + 'data["train/reward"]["3"] > 0.06' + From daaba23d190687c908a9bf7fdbd333bb7322092f Mon Sep 17 00:00:00 2001 From: ruit Date: Thu, 15 Jan 2026 01:47:58 -0800 Subject: [PATCH 4/4] add unit test Signed-off-by: ruit --- nemo_rl/algorithms/grpo.py | 14 +++----------- .../unit/models/generation/test_vllm_generation.py | 4 ++++ 2 files changed, 7 insertions(+), 11 deletions(-) diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index df97ff1366..bbc2e19a43 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -2037,11 +2037,7 @@ def async_grpo_train( if NEED_REFIT and POLICY_GENERATION_STALE: print("🔄 Refitting policy generation with actual model weights...") try: - refit_policy_generation( - policy, - policy_generation, - colocated_inference, - ) + refit_policy_generation(policy, policy_generation, colocated_inference) print("✅ Policy generation refit completed successfully") POLICY_GENERATION_STALE = False except Exception as e: @@ -2361,9 +2357,7 @@ def async_grpo_train( print("🔄 Performing policy generation refit...") with timer.time("weight_sync"): refit_policy_generation( - policy, - policy_generation, - colocated_inference, + policy, policy_generation, colocated_inference ) POLICY_GENERATION_STALE = False @@ -2388,9 +2382,7 @@ def async_grpo_train( if NEED_REFIT and POLICY_GENERATION_STALE: refit_policy_generation( - policy, - policy_generation, - colocated_inference, + policy, policy_generation, colocated_inference ) POLICY_GENERATION_STALE = False else: diff --git a/tests/unit/models/generation/test_vllm_generation.py b/tests/unit/models/generation/test_vllm_generation.py index 3f1d18c200..0244501e68 100644 --- a/tests/unit/models/generation/test_vllm_generation.py +++ b/tests/unit/models/generation/test_vllm_generation.py @@ -898,6 +898,8 @@ async def run_hf_train_process( # LoRA tests (False, False, "bfloat16", True), (False, True, "bfloat16", True), + (True, False, "bfloat16", True), + (True, True, "bfloat16", True), ], ) async def test_vllm_generation_with_hf_training_colocated( @@ -964,6 +966,8 @@ async def test_vllm_generation_with_hf_training_colocated( # LoRA tests (False, False, "bfloat16", True), (False, True, "bfloat16", True), + (True, False, "bfloat16", True), + (True, True, "bfloat16", True), ], ) async def test_vllm_generation_with_hf_training_non_colocated(