Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions nemo_rl/algorithms/grpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,10 +466,6 @@ def setup(
# Override the vLLM lora config with the DTensor lora config
generation_config["vllm_cfg"]["lora_cfg"] = lora_cfg

assert not _should_use_async_rollouts(master_config), (
"Async rollouts are not supported with LoRA in DTensor backend."
)

# Define initialization functions that will be used in all paths
def init_policy():
"""Initialize policy training workers."""
Expand Down
14 changes: 14 additions & 0 deletions nemo_rl/models/generation/vllm/vllm_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,20 @@ def _maybe_process_fp8_kv_cache(self) -> None:
target_device,
)

def apply_lora_patches(self) -> None:
    """Apply LoRA patches inside the vLLM worker process.

    Invoked via ``collective_rpc`` from the async worker so that each
    engine worker process patches vLLM's LoRA handling in its own
    process (the patches applied in the parent do not propagate).

    Raises:
        Exception: re-raises whatever the patching helper raised, after
            printing the message and full traceback for debuggability.
    """
    try:
        # Imported lazily: this code runs inside the vLLM worker process,
        # where the import should happen at patch time, not module load.
        from nemo_rl.models.generation.vllm.lora import apply_lora_patches

        apply_lora_patches()
    except Exception as e:
        print(f"Failed to apply LoRA patches in worker extension: {e}")
        import traceback as _tb

        print(_tb.format_exc())
        # Bare `raise` preserves the original traceback and exception
        # context; `raise e` would re-raise from this frame instead.
        raise

def _apply_weight_name_mapping(
self, weights: list[tuple[str, torch.Tensor]]
) -> list[tuple[str, torch.Tensor]]:
Expand Down
41 changes: 39 additions & 2 deletions nemo_rl/models/generation/vllm/vllm_worker_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,17 @@ def clear_vllm_logger_metrics(self) -> None:

async def post_init_async(self):
    """Finish worker initialization once the async engine is up."""
    self.vllm_device_ids = await self.report_device_id_async()

    # With the async engine, generation runs in separate engine worker
    # processes, so LoRA patches must be applied there as well.
    lora_on = getattr(self, "lora_enabled", False)
    if lora_on and self.llm is not None:
        try:
            await self.llm.collective_rpc("apply_lora_patches", args=tuple())
        except Exception as e:
            # Best-effort: log and continue rather than failing init here.
            print(
                f"[WARNING] Failed to apply lora patches in engine workers (async worker): {e}"
            )
        else:
            print(
                "Successfully applied lora patches in engine workers (async worker)"
            )

async def report_dp_openai_server_base_url(self) -> Optional[str]:
    """Return this worker's OpenAI-compatible server base URL, or None if unset."""
    return self.base_url
Expand Down Expand Up @@ -746,10 +757,21 @@ async def process_single_sample(sample_idx):

request_id = str(uuid.uuid4())

lora_req = None
if self.lora_enabled:
from vllm.lora.request import LoRARequest
from nemo_rl.models.generation.vllm.lora import get_vllm_lora_metadata

lora_metadata = get_vllm_lora_metadata()
lora_req = LoRARequest(
**lora_metadata,
)

# Generate using vLLM async engine
vllm_request_generator = self.llm.generate(
prompt=prompt,
sampling_params=sampling_params_for_request,
lora_request=lora_req,
request_id=request_id,
)

Expand Down Expand Up @@ -919,10 +941,21 @@ async def process_single_prompt(prompt_idx):

request_id = str(uuid.uuid4())

lora_req = None
if self.lora_enabled:
from vllm.lora.request import LoRARequest
from nemo_rl.models.generation.vllm.lora import get_vllm_lora_metadata

lora_metadata = get_vllm_lora_metadata()
lora_req = LoRARequest(
**lora_metadata,
)

# Generate using vLLM async engine
vllm_request_generator = self.llm.generate(
prompt=prompt,
sampling_params=sampling_params,
lora_request=lora_req,
request_id=request_id,
)

Expand Down Expand Up @@ -1027,7 +1060,10 @@ async def update_weights_via_ipc_zmq_async(
traceback.print_exc()
return False

async def update_weights_from_collective_async(self) -> bool:
async def update_weights_from_collective_async(
self,
refit_mode: Optional[str] = "base_model",
) -> bool:
"""Async version of update_weights_from_collective."""
try:
assert self.llm is not None, (
Expand All @@ -1040,7 +1076,8 @@ async def update_weights_from_collective_async(self) -> bool:
)

result_or_coro = await self.llm.collective_rpc(
"update_weights_from_collective", args=tuple()
"update_weights_from_collective",
args=(self.lora_cfg, refit_mode),
)

if asyncio.iscoroutine(result_or_coro):
Expand Down
1 change: 1 addition & 0 deletions tests/functional/L1_Functional_Tests_GPU.sh
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ time uv run --no-sync bash ./tests/functional/sft.sh
time uv run --no-sync bash ./tests/functional/sft_resume_diamond.sh
time uv run --no-sync bash ./tests/functional/grpo.sh
time uv run --no-sync bash ./tests/functional/grpo_async.sh
time uv run --no-sync bash ./tests/functional/grpo_automodel_lora_async.sh
time uv run --no-sync bash ./tests/functional/grpo_automodel_lora_non_colocated.sh
time uv run --no-sync bash ./tests/functional/grpo_automodel_lora.sh
time uv run --no-sync bash ./tests/functional/grpo_megatron.sh
Expand Down
52 changes: 52 additions & 0 deletions tests/functional/grpo_automodel_lora_async.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
#!/bin/bash
# Functional test: GRPO with AutoModel LoRA (DTensor backend) running the
# vLLM async engine in a non-colocated 2-GPU configuration.

# clean up checkpoint directory on exit
trap "rm -rf /tmp/lora_sft_checkpoints" EXIT

SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd)
PROJECT_ROOT=$(realpath $SCRIPT_DIR/../..)
# Mark the current repo as safe, since wandb fetches metadata about the repo
git config --global --add safe.directory $PROJECT_ROOT

set -eou pipefail

# Derive experiment name and output paths from this script's filename so
# each functional test writes into its own directory.
EXP_NAME=$(basename $0 .sh)
EXP_DIR=$SCRIPT_DIR/$EXP_NAME
LOG_DIR=$EXP_DIR/logs
JSON_METRICS=$EXP_DIR/metrics.json
RUN_LOG=$EXP_DIR/run.log
export PYTHONPATH=${PROJECT_ROOT}:${PYTHONPATH:-}

# Start from a clean slate for this experiment.
rm -rf $EXP_DIR $LOG_DIR
mkdir -p $EXP_DIR $LOG_DIR

cd $PROJECT_ROOT
# Short 3-step GRPO run with LoRA enabled and the async vLLM engine
# (generation non-colocated on 1 GPU, training on the other). Coverage is
# accumulated (-a) into the shared tests/.coverage data file.
NRL_FORCE_REBUILD_VENVS=true uv run coverage run -a --data-file=$PROJECT_ROOT/tests/.coverage --source=$PROJECT_ROOT/nemo_rl \
    $PROJECT_ROOT/examples/run_grpo_math.py\
    grpo.max_num_steps=3 \
    grpo.num_prompts_per_step=8 \
    grpo.num_generations_per_prompt=4 \
    policy.dtensor_cfg.lora_cfg.enabled=True \
    policy.dtensor_cfg.lora_cfg.dim=32 \
    policy.train_global_batch_size=32 \
    policy.train_micro_batch_size=1 \
    policy.generation.colocated.enabled=false \
    policy.generation.colocated.resources.gpus_per_node=1 \
    policy.generation.colocated.resources.num_nodes=1 \
    policy.generation.vllm_cfg.async_engine=true \
    grpo.async_grpo.enabled=true \
    loss_fn.use_importance_sampling_correction=true \
    cluster.gpus_per_node=2 \
    logger.tensorboard_enabled=true \
    logger.log_dir=$LOG_DIR \
    logger.wandb_enabled=false \
    logger.monitor_gpus=true \
    checkpointing.enabled=false \
    "$@" \
    2>&1 | tee $RUN_LOG

# Convert TensorBoard event files into a JSON metrics dump for assertions.
uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS

# Sanity-check the learning signal: reward at step 3 must exceed the floor.
uv run tests/check_metrics.py $JSON_METRICS \
    'data["train/reward"]["3"] > 0.06'
4 changes: 4 additions & 0 deletions tests/unit/models/generation/test_vllm_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -898,6 +898,8 @@ async def run_hf_train_process(
# LoRA tests
(False, False, "bfloat16", True),
(False, True, "bfloat16", True),
(True, False, "bfloat16", True),
(True, True, "bfloat16", True),
],
)
async def test_vllm_generation_with_hf_training_colocated(
Expand Down Expand Up @@ -964,6 +966,8 @@ async def test_vllm_generation_with_hf_training_colocated(
# LoRA tests
(False, False, "bfloat16", True),
(False, True, "bfloat16", True),
(True, False, "bfloat16", True),
(True, True, "bfloat16", True),
],
)
async def test_vllm_generation_with_hf_training_non_colocated(
Expand Down
Loading