diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py
index 78ad740858..bbc2e19a43 100644
--- a/nemo_rl/algorithms/grpo.py
+++ b/nemo_rl/algorithms/grpo.py
@@ -466,10 +466,6 @@ def setup(
         # Override the vLLM lora config with the DTensor lora config
         generation_config["vllm_cfg"]["lora_cfg"] = lora_cfg
 
-        assert not _should_use_async_rollouts(master_config), (
-            "Async rollouts are not supported with LoRA in DTensor backend."
-        )
-
     # Define initialization functions that will be used in all paths
     def init_policy():
         """Initialize policy training workers."""
diff --git a/nemo_rl/models/generation/vllm/vllm_backend.py b/nemo_rl/models/generation/vllm/vllm_backend.py
index a0c855b2f7..ca45e67b41 100644
--- a/nemo_rl/models/generation/vllm/vllm_backend.py
+++ b/nemo_rl/models/generation/vllm/vllm_backend.py
@@ -129,6 +129,20 @@ def _maybe_process_fp8_kv_cache(self) -> None:
                 target_device,
             )
 
+    def apply_lora_patches(self) -> None:
+        """Apply LoRA patches inside the vLLM worker process. Used for async worker."""
+        try:
+            from nemo_rl.models.generation.vllm.lora import apply_lora_patches
+
+            apply_lora_patches()
+
+        except Exception as e:
+            print(f"Failed to apply LoRA patches in worker extension: {e}")
+            import traceback as _tb
+
+            print(_tb.format_exc())
+            raise
+
     def _apply_weight_name_mapping(
         self, weights: list[tuple[str, torch.Tensor]]
     ) -> list[tuple[str, torch.Tensor]]:
diff --git a/nemo_rl/models/generation/vllm/vllm_worker_async.py b/nemo_rl/models/generation/vllm/vllm_worker_async.py
index 83fef65a7d..86cc64e678 100644
--- a/nemo_rl/models/generation/vllm/vllm_worker_async.py
+++ b/nemo_rl/models/generation/vllm/vllm_worker_async.py
@@ -276,6 +276,17 @@ def clear_vllm_logger_metrics(self) -> None:
 
     async def post_init_async(self):
         self.vllm_device_ids = await self.report_device_id_async()
+        # Ensure LoRA patches are applied inside engine worker processes (async path)
+        if getattr(self, "lora_enabled", False) and self.llm is not None:
+            try:
+                await self.llm.collective_rpc("apply_lora_patches", args=tuple())
+                print(
+                    "Successfully applied lora patches in engine workers (async worker)"
+                )
+            except Exception as e:
+                print(
+                    f"[WARNING] Failed to apply lora patches in engine workers (async worker): {e}"
+                )
 
     async def report_dp_openai_server_base_url(self) -> Optional[str]:
         return self.base_url
@@ -746,10 +757,21 @@ async def process_single_sample(sample_idx):
 
             request_id = str(uuid.uuid4())
 
+            lora_req = None
+            if self.lora_enabled:
+                from vllm.lora.request import LoRARequest
+                from nemo_rl.models.generation.vllm.lora import get_vllm_lora_metadata
+
+                lora_metadata = get_vllm_lora_metadata()
+                lora_req = LoRARequest(
+                    **lora_metadata,
+                )
+
             # Generate using vLLM async engine
             vllm_request_generator = self.llm.generate(
                 prompt=prompt,
                 sampling_params=sampling_params_for_request,
+                lora_request=lora_req,
                 request_id=request_id,
             )
 
@@ -919,10 +941,21 @@ async def process_single_prompt(prompt_idx):
 
             request_id = str(uuid.uuid4())
 
+            lora_req = None
+            if self.lora_enabled:
+                from vllm.lora.request import LoRARequest
+                from nemo_rl.models.generation.vllm.lora import get_vllm_lora_metadata
+
+                lora_metadata = get_vllm_lora_metadata()
+                lora_req = LoRARequest(
+                    **lora_metadata,
+                )
+
             # Generate using vLLM async engine
             vllm_request_generator = self.llm.generate(
                 prompt=prompt,
                 sampling_params=sampling_params,
+                lora_request=lora_req,
                 request_id=request_id,
             )
 
@@ -1027,7 +1060,10 @@ async def update_weights_via_ipc_zmq_async(
             traceback.print_exc()
             return False
 
-    async def update_weights_from_collective_async(self) -> bool:
+    async def update_weights_from_collective_async(
+        self,
+        refit_mode: Optional[str] = "base_model",
+    ) -> bool:
         """Async version of update_weights_from_collective."""
         try:
             assert self.llm is not None, (
@@ -1040,7 +1076,8 @@ async def update_weights_from_collective_async(self) -> bool:
             )
 
             result_or_coro = await self.llm.collective_rpc(
-                "update_weights_from_collective", args=tuple()
+                "update_weights_from_collective",
+                args=(self.lora_cfg, refit_mode),
             )
 
             if asyncio.iscoroutine(result_or_coro):
diff --git a/tests/functional/L1_Functional_Tests_GPU.sh b/tests/functional/L1_Functional_Tests_GPU.sh
index d22c3ba5c8..4cb1e6b2fb 100755
--- a/tests/functional/L1_Functional_Tests_GPU.sh
+++ b/tests/functional/L1_Functional_Tests_GPU.sh
@@ -27,6 +27,7 @@ time uv run --no-sync bash ./tests/functional/sft.sh
 time uv run --no-sync bash ./tests/functional/sft_resume_diamond.sh
 time uv run --no-sync bash ./tests/functional/grpo.sh
 time uv run --no-sync bash ./tests/functional/grpo_async.sh
+time uv run --no-sync bash ./tests/functional/grpo_automodel_lora_async.sh
 time uv run --no-sync bash ./tests/functional/grpo_automodel_lora_non_colocated.sh
 time uv run --no-sync bash ./tests/functional/grpo_automodel_lora.sh
 time uv run --no-sync bash ./tests/functional/grpo_megatron.sh
diff --git a/tests/functional/grpo_automodel_lora_async.sh b/tests/functional/grpo_automodel_lora_async.sh
new file mode 100755
index 0000000000..c3a99fedd0
--- /dev/null
+++ b/tests/functional/grpo_automodel_lora_async.sh
@@ -0,0 +1,52 @@
+#!/bin/bash
+
+# clean up checkpoint directory on exit
+trap "rm -rf /tmp/lora_sft_checkpoints" EXIT
+
+SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd)
+PROJECT_ROOT=$(realpath $SCRIPT_DIR/../..)
+# Mark the current repo as safe, since wandb fetches metadata about the repo
+git config --global --add safe.directory $PROJECT_ROOT
+
+set -eou pipefail
+
+EXP_NAME=$(basename $0 .sh)
+EXP_DIR=$SCRIPT_DIR/$EXP_NAME
+LOG_DIR=$EXP_DIR/logs
+JSON_METRICS=$EXP_DIR/metrics.json
+RUN_LOG=$EXP_DIR/run.log
+export PYTHONPATH=${PROJECT_ROOT}:${PYTHONPATH:-}
+
+rm -rf $EXP_DIR $LOG_DIR
+mkdir -p $EXP_DIR $LOG_DIR
+
+cd $PROJECT_ROOT
+NRL_FORCE_REBUILD_VENVS=true uv run coverage run -a --data-file=$PROJECT_ROOT/tests/.coverage --source=$PROJECT_ROOT/nemo_rl \
+    $PROJECT_ROOT/examples/run_grpo_math.py\
+    grpo.max_num_steps=3 \
+    grpo.num_prompts_per_step=8 \
+    grpo.num_generations_per_prompt=4 \
+    policy.dtensor_cfg.lora_cfg.enabled=True \
+    policy.dtensor_cfg.lora_cfg.dim=32 \
+    policy.train_global_batch_size=32 \
+    policy.train_micro_batch_size=1 \
+    policy.generation.colocated.enabled=false \
+    policy.generation.colocated.resources.gpus_per_node=1 \
+    policy.generation.colocated.resources.num_nodes=1 \
+    policy.generation.vllm_cfg.async_engine=true \
+    grpo.async_grpo.enabled=true \
+    loss_fn.use_importance_sampling_correction=true \
+    cluster.gpus_per_node=2 \
+    logger.tensorboard_enabled=true \
+    logger.log_dir=$LOG_DIR \
+    logger.wandb_enabled=false \
+    logger.monitor_gpus=true \
+    checkpointing.enabled=false \
+    "$@" \
+    2>&1 | tee $RUN_LOG
+
+uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS
+
+uv run tests/check_metrics.py $JSON_METRICS \
+    'data["train/reward"]["3"] > 0.06'
+
diff --git a/tests/unit/models/generation/test_vllm_generation.py b/tests/unit/models/generation/test_vllm_generation.py
index 3f1d18c200..0244501e68 100644
--- a/tests/unit/models/generation/test_vllm_generation.py
+++ b/tests/unit/models/generation/test_vllm_generation.py
@@ -898,6 +898,8 @@ async def run_hf_train_process(
         # LoRA tests
         (False, False, "bfloat16", True),
         (False, True, "bfloat16", True),
+        (True, False, "bfloat16", True),
+        (True, True, "bfloat16", True),
     ],
 )
 async def test_vllm_generation_with_hf_training_colocated(
@@ -964,6 +966,8 @@ async def test_vllm_generation_with_hf_training_colocated(
         # LoRA tests
         (False, False, "bfloat16", True),
         (False, True, "bfloat16", True),
+        (True, False, "bfloat16", True),
+        (True, True, "bfloat16", True),
     ],
 )
 async def test_vllm_generation_with_hf_training_non_colocated(