diff --git a/examples/configs/grpo_math_1B.yaml b/examples/configs/grpo_math_1B.yaml index a7a9fe86b6..99bf8615f1 100644 --- a/examples/configs/grpo_math_1B.yaml +++ b/examples/configs/grpo_math_1B.yaml @@ -90,6 +90,19 @@ policy: tensor_parallel_size: 1 context_parallel_size: 1 custom_parallel_plan: null + + # LoRA (Low-Rank Adaptation) Configuration + lora_cfg: + enabled: False # Set to True to enable LoRA fine-tuning + target_modules: [] # List of module names to apply LoRA (empty list with match_all_linear=true applies to all linear layers) + exclude_modules: [] # List of module names to exclude from LoRA + match_all_linear: true # If True, applies LoRA to all linear layers (overrides target_modules) + dim: 8 # LoRA rank (r): lower rank = fewer parameters but less capacity. Typical values: 4, 8, 16, 32, 64 + alpha: 32 # LoRA scaling factor: effective learning rate multiplier = alpha/dim. Typical values: 16, 32, 64 + dropout: 0.0 # Dropout probability applied to LoRA layers (0.0 = no dropout) + dropout_position: "post" # Where to apply dropout: "pre" (before LoRA) or "post" (after LoRA) + lora_A_init: "xavier" # Initialization method for LoRA A matrix: "xavier" or "uniform" + use_triton: true # Use Triton-optimized kernels for LoRA (faster but requires flash-attn). 
Disable when tensor_parallel_size > 1 megatron_cfg: enabled: false diff --git a/examples/configs/recipes/llm/grpo-qwen3-8B-base-1n8g-fsdp2-lora.yaml b/examples/configs/recipes/llm/grpo-qwen3-8B-base-1n8g-fsdp2-lora.yaml new file mode 100644 index 0000000000..f5dc334359 --- /dev/null +++ b/examples/configs/recipes/llm/grpo-qwen3-8B-base-1n8g-fsdp2-lora.yaml @@ -0,0 +1,25 @@ +defaults: ../../grpo_math_1B.yaml +grpo: + val_at_start: true +checkpointing: + checkpoint_dir: results/grpo-qwen3-8B-base-1n8g-fsdp2-lora +policy: + model_name: Qwen/Qwen3-8B-Base + max_total_sequence_length: 2048 + dtensor_cfg: + activation_checkpointing: true + lora_cfg: + enabled: True + dim: 128 + alpha: 128 + sequence_packing: + enabled: false +logger: + log_dir: logs/grpo-qwen3-8B-base-1n8g-fsdp2-lora + wandb_enabled: true + tensorboard_enabled: true + wandb: + project: nemo-rl + name: grpo-qwen3-8B-base-1n8g-fsdp2-lora +cluster: + gpus_per_node: 8 diff --git a/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py b/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py index 3bc44367d8..38d25e8a6f 100644 --- a/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py +++ b/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py @@ -21,6 +21,7 @@ import ray import torch +from nemo_automodel.components._peft.lora import LinearLoRA from nemo_automodel.components.distributed.cp_utils import ( create_context_parallel_ctx, ) @@ -85,19 +86,66 @@ def dtensor_params_generator( Args: model: The model whose parameters to generate. target_dtype: The dtype to convert tensors to. + Note: LoRA adapter weights (lora_A/lora_B) are skipped and merged into their base weights before yielding. Yields: Tuples of (fully_qualified_name, tensor) where tensors are converted to target dtype and made contiguous. 
""" + module_map = dict(model.named_modules()) for name, tensor in model.state_dict().items(): + if name.endswith(".lora_A.weight") or name.endswith(".lora_B.weight"): + continue full_tensor = tensor.full_tensor() if isinstance(tensor, DTensor) else tensor - adapted_fqn_tensors = _maybe_adapt_tensor_to_hf(model, name, full_tensor) + merged_tensor = _maybe_merge_lora_weight(module_map, name, full_tensor) + + adapted_fqn_tensors = _maybe_adapt_tensor_to_hf(model, name, merged_tensor) for adapted_fqn, adapted_tensor in adapted_fqn_tensors: # Convert to target dtype yield ( adapted_fqn, adapted_tensor.to(target_dtype, non_blocking=True).contiguous(), ) + del adapted_tensor + del adapted_fqn_tensors + del merged_tensor + del full_tensor + + +@torch.no_grad() +def _maybe_merge_lora_weight( + module_map: dict[str, nn.Module], + fqn: str, + tensor: torch.Tensor, +) -> torch.Tensor: + if not fqn.endswith(".weight"): + return tensor + module_name = fqn[: -len(".weight")] + module = module_map.get(module_name) + if not isinstance(module, LinearLoRA): + return tensor + if not (hasattr(module, "lora_A") and hasattr(module, "lora_B")): + return tensor + + lora_a = ( + module.lora_A.weight.full_tensor() + if isinstance(module.lora_A.weight, DTensor) + else module.lora_A.weight + ) + lora_b = ( + module.lora_B.weight.full_tensor() + if isinstance(module.lora_B.weight, DTensor) + else module.lora_B.weight + ) + lora_a = lora_a.to(device=tensor.device, dtype=tensor.dtype) + lora_b = lora_b.to(device=tensor.device, dtype=tensor.dtype) + scale = getattr(module, "scale", None) + + if scale is None and hasattr(module, "alpha") and hasattr(module, "dim"): + scale = module.alpha / module.dim + if scale is None: + scale = 1.0 + + return tensor + torch.matmul(lora_b, lora_a) * scale def _maybe_adapt_tensor_to_hf( @@ -1208,6 +1256,8 @@ def prepare_refit_info(self) -> Optional[dict[str, Any]]: """Prepare state dict metadata for weight refitting and IPC streaming.""" state_dict_info = {} for 
name, tensor in self.model.state_dict().items(): + if name.endswith(".lora_A.weight") or name.endswith(".lora_B.weight"): + continue full_tensor = ( tensor.full_tensor() if isinstance(tensor, DTensor) else tensor ) diff --git a/tests/functional/L1_Functional_Tests_GPU.sh b/tests/functional/L1_Functional_Tests_GPU.sh index 095a01c447..311b387e91 100644 --- a/tests/functional/L1_Functional_Tests_GPU.sh +++ b/tests/functional/L1_Functional_Tests_GPU.sh @@ -27,6 +27,9 @@ time uv run --no-sync bash ./tests/functional/sft.sh time uv run --no-sync bash ./tests/functional/sft_resume_diamond.sh time uv run --no-sync bash ./tests/functional/grpo.sh time uv run --no-sync bash ./tests/functional/grpo_async.sh +time uv run --no-sync bash ./tests/functional/grpo_automodel_lora_async.sh +time uv run --no-sync bash ./tests/functional/grpo_automodel_lora_non_colocated.sh +time uv run --no-sync bash ./tests/functional/grpo_automodel_lora.sh time uv run --no-sync bash ./tests/functional/grpo_megatron.sh time uv run --no-sync bash ./tests/functional/grpo_megatron_generation.sh time uv run --no-sync bash ./tests/functional/grpo_multiturn.sh diff --git a/tests/functional/grpo_automodel_lora.sh b/tests/functional/grpo_automodel_lora.sh new file mode 100755 index 0000000000..4cd7d2bcc1 --- /dev/null +++ b/tests/functional/grpo_automodel_lora.sh @@ -0,0 +1,46 @@ +#!/bin/bash + +# clean up checkpoint directory on exit +trap "rm -rf /tmp/lora_sft_checkpoints" EXIT + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) +PROJECT_ROOT=$(realpath $SCRIPT_DIR/../..) 
+# Mark the current repo as safe, since wandb fetches metadata about the repo +git config --global --add safe.directory $PROJECT_ROOT + +set -eou pipefail + +EXP_NAME=$(basename $0 .sh) +EXP_DIR=$SCRIPT_DIR/$EXP_NAME +LOG_DIR=$EXP_DIR/logs +JSON_METRICS=$EXP_DIR/metrics.json +RUN_LOG=$EXP_DIR/run.log +export PYTHONPATH=${PROJECT_ROOT}:${PYTHONPATH:-} + +rm -rf $EXP_DIR $LOG_DIR +mkdir -p $EXP_DIR $LOG_DIR + +cd $PROJECT_ROOT +uv run coverage run -a --data-file=$PROJECT_ROOT/tests/.coverage --source=$PROJECT_ROOT/nemo_rl \ + $PROJECT_ROOT/examples/run_grpo_math.py\ + grpo.max_num_steps=3 \ + grpo.num_prompts_per_step=8 \ + grpo.num_generations_per_prompt=4 \ + data.shuffle=false \ + policy.dtensor_cfg.lora_cfg.enabled=True \ + policy.dtensor_cfg.lora_cfg.dim=32 \ + policy.train_global_batch_size=32 \ + policy.train_micro_batch_size=1 \ + cluster.gpus_per_node=2 \ + logger.tensorboard_enabled=true \ + logger.log_dir=$LOG_DIR \ + logger.wandb_enabled=false \ + logger.monitor_gpus=true \ + checkpointing.enabled=false \ + "$@" \ + 2>&1 | tee $RUN_LOG + +uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS + +uv run tests/check_metrics.py $JSON_METRICS \ + 'max(data["train/reward"]) > 0.03' diff --git a/tests/functional/grpo_automodel_lora_async.sh b/tests/functional/grpo_automodel_lora_async.sh new file mode 100755 index 0000000000..26e17ec992 --- /dev/null +++ b/tests/functional/grpo_automodel_lora_async.sh @@ -0,0 +1,52 @@ +#!/bin/bash + +# clean up checkpoint directory on exit +trap "rm -rf /tmp/lora_sft_checkpoints" EXIT + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) +PROJECT_ROOT=$(realpath $SCRIPT_DIR/../..) 
+# Mark the current repo as safe, since wandb fetches metadata about the repo +git config --global --add safe.directory $PROJECT_ROOT + +set -eou pipefail + +EXP_NAME=$(basename $0 .sh) +EXP_DIR=$SCRIPT_DIR/$EXP_NAME +LOG_DIR=$EXP_DIR/logs +JSON_METRICS=$EXP_DIR/metrics.json +RUN_LOG=$EXP_DIR/run.log +export PYTHONPATH=${PROJECT_ROOT}:${PYTHONPATH:-} + +rm -rf $EXP_DIR $LOG_DIR +mkdir -p $EXP_DIR $LOG_DIR + +cd $PROJECT_ROOT +NRL_FORCE_REBUILD_VENVS=true uv run coverage run -a --data-file=$PROJECT_ROOT/tests/.coverage --source=$PROJECT_ROOT/nemo_rl \ + $PROJECT_ROOT/examples/run_grpo_math.py\ + grpo.max_num_steps=3 \ + grpo.num_prompts_per_step=8 \ + grpo.num_generations_per_prompt=4 \ + data.shuffle=false \ + policy.dtensor_cfg.lora_cfg.enabled=True \ + policy.dtensor_cfg.lora_cfg.dim=32 \ + policy.train_global_batch_size=32 \ + policy.train_micro_batch_size=1 \ + policy.generation.colocated.enabled=false \ + policy.generation.colocated.resources.gpus_per_node=1 \ + policy.generation.colocated.resources.num_nodes=1 \ + policy.generation.vllm_cfg.async_engine=true \ + grpo.async_grpo.enabled=true \ + loss_fn.use_importance_sampling_correction=true \ + cluster.gpus_per_node=2 \ + logger.tensorboard_enabled=true \ + logger.log_dir=$LOG_DIR \ + logger.wandb_enabled=false \ + logger.monitor_gpus=true \ + checkpointing.enabled=false \ + "$@" \ + 2>&1 | tee $RUN_LOG + +uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS + +uv run tests/check_metrics.py $JSON_METRICS \ + 'max(data["train/reward"]) > 0.03' diff --git a/tests/functional/grpo_automodel_lora_non_colocated.sh b/tests/functional/grpo_automodel_lora_non_colocated.sh new file mode 100755 index 0000000000..395b46d814 --- /dev/null +++ b/tests/functional/grpo_automodel_lora_non_colocated.sh @@ -0,0 +1,49 @@ +#!/bin/bash + +# clean up checkpoint directory on exit +trap "rm -rf /tmp/lora_sft_checkpoints" EXIT + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) 
+PROJECT_ROOT=$(realpath $SCRIPT_DIR/../..) +# Mark the current repo as safe, since wandb fetches metadata about the repo +git config --global --add safe.directory $PROJECT_ROOT + +set -eou pipefail + +EXP_NAME=$(basename $0 .sh) +EXP_DIR=$SCRIPT_DIR/$EXP_NAME +LOG_DIR=$EXP_DIR/logs +JSON_METRICS=$EXP_DIR/metrics.json +RUN_LOG=$EXP_DIR/run.log +export PYTHONPATH=${PROJECT_ROOT}:${PYTHONPATH:-} + +rm -rf $EXP_DIR $LOG_DIR +mkdir -p $EXP_DIR $LOG_DIR + +cd $PROJECT_ROOT +NRL_FORCE_REBUILD_VENVS=true uv run coverage run -a --data-file=$PROJECT_ROOT/tests/.coverage --source=$PROJECT_ROOT/nemo_rl \ + $PROJECT_ROOT/examples/run_grpo_math.py\ + grpo.max_num_steps=3 \ + grpo.num_prompts_per_step=8 \ + grpo.num_generations_per_prompt=4 \ + data.shuffle=false \ + policy.dtensor_cfg.lora_cfg.enabled=True \ + policy.dtensor_cfg.lora_cfg.dim=32 \ + policy.train_global_batch_size=32 \ + policy.train_micro_batch_size=1 \ + policy.generation.colocated.enabled=false \ + policy.generation.colocated.resources.gpus_per_node=1 \ + policy.generation.colocated.resources.num_nodes=1 \ + cluster.gpus_per_node=2 \ + logger.tensorboard_enabled=true \ + logger.log_dir=$LOG_DIR \ + logger.wandb_enabled=false \ + logger.monitor_gpus=true \ + checkpointing.enabled=false \ + "$@" \ + 2>&1 | tee $RUN_LOG + +uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS + +uv run tests/check_metrics.py $JSON_METRICS \ + 'max(data["train/reward"]) > 0.03' diff --git a/tests/test_suites/llm/grpo-qwen3-8B-base-1n8g-fsdp2-lora.sh b/tests/test_suites/llm/grpo-qwen3-8B-base-1n8g-fsdp2-lora.sh new file mode 100755 index 0000000000..ecca27fb0e --- /dev/null +++ b/tests/test_suites/llm/grpo-qwen3-8B-base-1n8g-fsdp2-lora.sh @@ -0,0 +1,44 @@ +#!/bin/bash +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) +source $SCRIPT_DIR/common.env + +# ===== BEGIN CONFIG ===== +NUM_NODES=1 +STEPS_PER_RUN=20 +MAX_STEPS=20 +NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN )) 
# Round up +NUM_MINUTES=30 +# ===== END CONFIG ===== + +exit_if_max_steps_reached + +# Run the experiment +cd $PROJECT_ROOT +uv run examples/run_grpo_math.py \ + --config $CONFIG_PATH \ + grpo.max_num_steps=$MAX_STEPS \ + logger.log_dir=$LOG_DIR \ + logger.wandb_enabled=True \ + logger.wandb.project=nemo-rl \ + logger.wandb.name=$EXP_NAME \ + logger.monitor_gpus=True \ + logger.tensorboard_enabled=True \ + checkpointing.enabled=True \ + checkpointing.checkpoint_dir=$CKPT_DIR \ + $@ \ + 2>&1 | tee $RUN_LOG + +# Convert tensorboard logs to json +uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS + +# Only run metrics if the target step is reached +if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | map(tonumber) | max' $JSON_METRICS) -ge $MAX_STEPS ]]; then + uv run tests/check_metrics.py $JSON_METRICS \ + 'mean(data["train/gen_kl_error"], 20) < 0.002' \ + 'data["train/gen_kl_error"]["20"] < 0.002' \ + 'max(data["train/reward"]) > 0.35' \ + 'mean(data["timing/train/total_step_time"], 2) < 80' + + # Clean up checkpoint directory after successful run to save space. 
+ rm -rf "$CKPT_DIR" +fi \ No newline at end of file diff --git a/tests/test_suites/nightly.txt b/tests/test_suites/nightly.txt index 9e37d7ed01..19921ec2b3 100644 --- a/tests/test_suites/nightly.txt +++ b/tests/test_suites/nightly.txt @@ -66,6 +66,9 @@ tests/test_suites/llm/grpo-llama3.1-8b-instruct-2n8g-fsdp2tp1-noncolocated.sh tests/test_suites/llm/grpo-nano-v2-12b-1n8g-megatron.sh tests/test_suites/llm/grpo-nano-v2-12b-2n8g-fsdp2tp1.sh +# lora +tests/test_suites/llm/grpo-qwen3-8B-base-1n8g-fsdp2-lora.sh + ####### # SFT # ####### diff --git a/tests/unit/models/generation/test_vllm_generation.py b/tests/unit/models/generation/test_vllm_generation.py index 87730c8908..c27a183b5c 100644 --- a/tests/unit/models/generation/test_vllm_generation.py +++ b/tests/unit/models/generation/test_vllm_generation.py @@ -35,7 +35,7 @@ from nemo_rl.models.generation.vllm.vllm_worker_async import ( _replace_prefix_tokens, ) -from nemo_rl.models.policy import PolicyConfig +from nemo_rl.models.policy import LoRAConfig, PolicyConfig from nemo_rl.models.policy.lm_policy import Policy model_name = "Qwen/Qwen3-0.6B" @@ -70,6 +70,7 @@ "skip_tokenizer_init": False, "load_format": "auto", "enforce_eager": "False", + "kv_cache_dtype": "auto", }, "colocated": { "enabled": True, @@ -105,6 +106,7 @@ }, }, "dtensor_cfg": { + "_v2": False, "enabled": True, "cpu_offload": False, "sequence_parallel": False, @@ -127,6 +129,19 @@ "generation": deepcopy(basic_vllm_test_config), } +basic_lora_test_config: LoRAConfig = { + "enabled": False, + "target_modules": [], + "exclude_modules": [], + "match_all_linear": True, + "dim": 8, + "alpha": 32, + "dropout": 0.0, + "dropout_position": "post", + "lora_A_init": "xavier", + "use_triton": False, +} + def get_basic_megatron_test_config( tp: int = 1, @@ -691,7 +706,13 @@ def configure_worker_fixed_seed(num_gpus, bundle_indices=None): async def run_hf_train_process( - lm_policy, vllm_policy, tokenizer, async_engine, colocated, vllm_precision + lm_policy, + 
vllm_policy, + tokenizer, + async_engine, + colocated, + vllm_precision, + enable_lora, ): """Validates that the two policies can work together. @@ -871,16 +892,19 @@ async def run_hf_train_process( @pytest.mark.timeout(300) @pytest.mark.asyncio @pytest.mark.parametrize( - ("async_engine", "cpu_offload", "vllm_precision"), + ("async_engine", "cpu_offload", "vllm_precision", "enable_lora"), [ - (True, False, "bfloat16"), - (False, True, "bfloat16"), - (True, False, "fp8"), - (False, True, "fp8"), + (True, False, "bfloat16", False), + (False, True, "bfloat16", False), + (True, False, "fp8", False), + (False, True, "fp8", False), + # LoRA tests + (False, False, "bfloat16", True), + (True, False, "bfloat16", True), ], ) async def test_vllm_generation_with_hf_training_colocated( - cluster, tokenizer, async_engine, cpu_offload, vllm_precision + cluster, tokenizer, async_engine, cpu_offload, vllm_precision, enable_lora ): """This test validates that DTensor policy can work together with colocated vLLM policy.""" @@ -897,6 +921,8 @@ async def test_vllm_generation_with_hf_training_colocated( vllm_config = deepcopy(basic_vllm_test_config) vllm_config["vllm_cfg"]["async_engine"] = async_engine vllm_config["vllm_cfg"]["precision"] = vllm_precision + vllm_config["vllm_cfg"]["lora_cfg"] = deepcopy(basic_lora_test_config) + vllm_config["vllm_cfg"]["lora_cfg"]["enabled"] = enable_lora vllm_config = configure_generation_config(vllm_config, tokenizer) vllm_policy = VllmGeneration(cluster, vllm_config) @@ -906,6 +932,9 @@ async def test_vllm_generation_with_hf_training_colocated( print("Creating DTensor policy...") dtensor_config = deepcopy(basic_dtensor_test_config) dtensor_config["dtensor_cfg"]["cpu_offload"] = cpu_offload + dtensor_config["dtensor_cfg"]["_v2"] = enable_lora + dtensor_config["dtensor_cfg"]["lora_cfg"] = deepcopy(basic_lora_test_config) + dtensor_config["dtensor_cfg"]["lora_cfg"]["enabled"] = enable_lora dtensor_config["train_global_batch_size"] = 4 lm_policy = 
Policy(cluster, dtensor_config, tokenizer) @@ -916,23 +945,37 @@ async def test_vllm_generation_with_hf_training_colocated( # Test await run_hf_train_process( - lm_policy, vllm_policy, tokenizer, async_engine, True, vllm_precision + lm_policy, + vllm_policy, + tokenizer, + async_engine, + True, + vllm_precision, + enable_lora, ) @pytest.mark.timeout(300) @pytest.mark.asyncio @pytest.mark.parametrize( - ("async_engine", "cpu_offload", "vllm_precision"), + ("async_engine", "cpu_offload", "vllm_precision", "enable_lora"), [ - (True, False, "bfloat16"), - (False, True, "bfloat16"), - (True, False, "fp8"), - (False, True, "fp8"), + (True, False, "bfloat16", False), + (False, True, "bfloat16", False), + (True, False, "fp8", False), + (False, True, "fp8", False), + # LoRA tests + (False, False, "bfloat16", True), + (True, False, "bfloat16", True), ], ) async def test_vllm_generation_with_hf_training_non_colocated( - policy_cluster_separate, tokenizer, async_engine, cpu_offload, vllm_precision + policy_cluster_separate, + tokenizer, + async_engine, + cpu_offload, + vllm_precision, + enable_lora, ): # Skip the fp8 tests if the GPU is not H100 or newer (compute capability < 9.0) if vllm_precision == "fp8": @@ -948,19 +991,30 @@ async def test_vllm_generation_with_hf_training_non_colocated( # Create VllmGeneration Policy print("Creating vLLM policy...") vllm_config = deepcopy(basic_vllm_test_config) + vllm_config["vllm_cfg"]["lora_cfg"] = deepcopy(basic_lora_test_config) vllm_config["vllm_cfg"]["async_engine"] = async_engine vllm_config["vllm_cfg"]["precision"] = vllm_precision + vllm_config["vllm_cfg"]["lora_cfg"]["enabled"] = enable_lora vllm_config["colocated"]["enabled"] = False + if vllm_precision == "fp8": + vllm_config["vllm_cfg"]["kv_cache_dtype"] = "fp8" vllm_config = configure_generation_config(vllm_config, tokenizer) vllm_policy = VllmGeneration(generation_cluster_separate, vllm_config) vllm_policy.finish_generation() + assert not (enable_lora and vllm_precision == 
"fp8"), ( + "LoRA is not supported with FP8" + ) # Create Policy print("Creating DTensor policy...") dtensor_config = deepcopy(basic_dtensor_test_config) dtensor_config["generation"]["colocated"]["enabled"] = False dtensor_config["dtensor_cfg"]["cpu_offload"] = cpu_offload dtensor_config["train_global_batch_size"] = 4 + # lora must use dtensor v2 + dtensor_config["dtensor_cfg"]["_v2"] = enable_lora + dtensor_config["dtensor_cfg"]["lora_cfg"] = deepcopy(basic_lora_test_config) + dtensor_config["dtensor_cfg"]["lora_cfg"]["enabled"] = enable_lora lm_policy = Policy(policy_cluster_separate, dtensor_config, tokenizer) # Refit @@ -983,7 +1037,13 @@ async def test_vllm_generation_with_hf_training_non_colocated( # Test await run_hf_train_process( - lm_policy, vllm_policy, tokenizer, async_engine, False, vllm_precision + lm_policy, + vllm_policy, + tokenizer, + async_engine, + False, + vllm_precision, + enable_lora, )