diff --git a/examples/configs/grpo_math_1B.yaml b/examples/configs/grpo_math_1B.yaml index ec8c2c5ecc..3d28bb20cc 100644 --- a/examples/configs/grpo_math_1B.yaml +++ b/examples/configs/grpo_math_1B.yaml @@ -90,6 +90,18 @@ policy: tensor_parallel_size: 1 context_parallel_size: 1 custom_parallel_plan: null + # LoRA (Low-Rank Adaptation) Configuration + lora_cfg: + enabled: False # Set to True to enable LoRA fine-tuning + target_modules: [] # List of module names to apply LoRA (empty list with match_all_linear=true applies to all linear layers) + exclude_modules: [] # List of module names to exclude from LoRA + match_all_linear: true # If True, applies LoRA to all linear layers (overrides target_modules) + dim: 8 # LoRA rank (r): lower rank = fewer parameters but less capacity. Typical values: 4, 8, 16, 32, 64 + alpha: 32 # LoRA scaling factor: effective learning rate multiplier = alpha/dim. Typical values: 16, 32, 64 + dropout: 0.0 # Dropout probability applied to LoRA layers (0.0 = no dropout) + dropout_position: "post" # Where to apply dropout: "pre" (before LoRA) or "post" (after LoRA) + lora_A_init: "xavier" # Initialization method for LoRA A matrix: "xavier" or "uniform" + use_triton: true # Use Triton-optimized kernels for LoRA (faster but requires flash-attn). 
Disable when tensor_parallel_size > 1 megatron_cfg: enabled: false diff --git a/examples/configs/recipes/llm/grpo-qwen3-8B-base-1n8g-fsdp2-lora.yaml b/examples/configs/recipes/llm/grpo-qwen3-8B-base-1n8g-fsdp2-lora.yaml new file mode 100644 index 0000000000..f5dc334359 --- /dev/null +++ b/examples/configs/recipes/llm/grpo-qwen3-8B-base-1n8g-fsdp2-lora.yaml @@ -0,0 +1,25 @@ +defaults: ../../grpo_math_1B.yaml +grpo: + val_at_start: true +checkpointing: + checkpoint_dir: results/grpo-qwen3-8B-base-1n8g-fsdp2-lora +policy: + model_name: Qwen/Qwen3-8B-Base + max_total_sequence_length: 2048 + dtensor_cfg: + activation_checkpointing: true + lora_cfg: + enabled: True + dim: 128 + alpha: 128 + sequence_packing: + enabled: false +logger: + log_dir: logs/grpo-qwen3-8B-base-1n8g-fsdp2-lora + wandb_enabled: true + tensorboard_enabled: true + wandb: + project: nemo-rl + name: grpo-qwen3-8B-base-1n8g-fsdp2-lora +cluster: + gpus_per_node: 8 diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index 7597b33f95..92ccc610e7 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -464,6 +464,23 @@ def setup( ) policy_config["megatron_cfg"]["train_iters"] = total_train_iters + if "dtensor_cfg" in policy_config and policy_config["dtensor_cfg"]["enabled"]: + lora_cfg = ( + policy_config["dtensor_cfg"]["lora_cfg"] + if "lora_cfg" in policy_config["dtensor_cfg"] + else None + ) + if lora_cfg is not None and lora_cfg.get("enabled", False): + # Override the vLLM lora config with the DTensor lora config + generation_config["vllm_cfg"]["lora_cfg"] = lora_cfg + + assert colocated_inference, ( + "LoRA in DTensor backend is only supported with colocated inference." + ) + assert not _should_use_async_rollouts(master_config), ( + "Async rollouts are not supported with LoRA in DTensor backend.
+ ) + # Define initialization functions that will be used in all paths def init_policy(): """Initialize policy training workers.""" @@ -505,6 +522,9 @@ def init_vllm(): assert loss_config["use_importance_sampling_correction"] is True, ( "Importance sampling must be enabled for vLLM FP8 generation for good convergence!" ) + assert not policy_config.get("dtensor_cfg", {}).get("lora_cfg", {}).get("enabled", False), ( + "LoRA is not supported with vLLM FP8 generation." + ) if generation_config["vllm_cfg"]["kv_cache_dtype"].startswith("fp8"): # FP8 KV cache requires FP8 model precision assert generation_config["vllm_cfg"]["precision"] == "fp8", ( @@ -933,18 +953,11 @@ def refit_policy_generation( timer: Optional Timer used to time the prepare/transfer/update phase kv_scales: Optional dictionary of KV cache scales for FP8 quantization. """ - if colocated_inference: - policy.offload_before_refit() - policy_generation.prepare_for_generation(tags=["weights"]) - # Create a context manager that does nothing when timer is None - timer_context = ( - timer.time("prepare_for_generation/transfer_and_update_weights") - if timer is not None - else nullcontext() - ) - with timer_context: - # update weights + def _perform_refit_weights(refit_mode: str): + assert refit_mode in ("base_model", "lora"), ( + "refit_mode must be either 'base_model' or 'lora'" + ) update_success = False if colocated_inference: # get model param keys, which is grouped by size @@ -959,9 +972,13 @@ def refit_policy_generation( ) futures_train = policy.stream_weights_via_ipc_zmq( - buffer_size_bytes=buffer_size_bytes, kv_scales=kv_scales + buffer_size_bytes=buffer_size_bytes, + kv_scales=kv_scales, + refit_mode=refit_mode, + ) + futures_inference = policy_generation.update_weights_via_ipc_zmq( + refit_mode=refit_mode, ) - futures_inference = policy_generation.update_weights_via_ipc_zmq() # wait for all futures to complete ray.get(futures_train) results = ray.get(futures_inference) @@ -985,6 +1002,35 @@ def refit_policy_generation( ) raise
RuntimeError(error_message) + lora_enabled, lora_base_refit_done = policy.check_lora_base_refit_done() + refit_lora_weights = lora_enabled + refit_base_model_weights = not lora_enabled or not lora_base_refit_done + + if colocated_inference: + policy.offload_before_refit() + policy_generation.prepare_for_generation(tags=["weights"]) + + # Create a context manager that does nothing when timer is None + timer_context = ( + timer.time("prepare_for_generation/transfer_and_update_weights") + if timer is not None + else nullcontext() + ) + with timer_context: + if refit_base_model_weights: + _perform_refit_weights(refit_mode="base_model") + print( + " ▶ Refitting base model weights...", + flush=True, + ) + + if refit_lora_weights: + _perform_refit_weights(refit_mode="lora") + print( + " ▶ Refitting LoRA weights...", + flush=True, + ) + if colocated_inference: policy.offload_after_refit() policy_generation.prepare_for_generation(tags=["kv_cache"]) diff --git a/nemo_rl/models/generation/interfaces.py b/nemo_rl/models/generation/interfaces.py index d134027bdf..ca74c7ab1a 100644 --- a/nemo_rl/models/generation/interfaces.py +++ b/nemo_rl/models/generation/interfaces.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
from abc import ABC, abstractmethod -from typing import Any, NotRequired, TypedDict, Union +from typing import Any, NotRequired, Optional, TypedDict, Union import ray import torch @@ -245,7 +245,9 @@ def prepare_refit_info(self, state_dict_info: dict[str, Any]) -> None: """Prepare the info for refit.""" raise NotImplementedError - def update_weights_via_ipc_zmq(self) -> list[ray.ObjectRef]: + def update_weights_via_ipc_zmq( + self, refit_mode: Optional[str] = "base_model" + ) -> list[ray.ObjectRef]: """Update the model weights from the given IPC handles.""" raise NotImplementedError diff --git a/nemo_rl/models/generation/vllm/lora.py b/nemo_rl/models/generation/vllm/lora.py new file mode 100644 index 0000000000..a2f59c9439 --- /dev/null +++ b/nemo_rl/models/generation/vllm/lora.py @@ -0,0 +1,211 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +from typing import Any, Optional + +import torch +import vllm +from torch import nn +from vllm.lora.peft_helper import PEFTHelper +from vllm.lora.request import LoRARequest +from vllm.lora.utils import get_adapter_absolute_path +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.linear import LinearBase + + +class LoRARequestWithCfgAndWeights(LoRARequest): + lora_cfg: Optional[dict] = None + lora_weights: Optional[dict[str, Any]] = None + + +def patched_load_adapter(self, lora_request: LoRARequestWithCfgAndWeights): + try: + supported_lora_modules = self._adapter_manager.supported_lora_modules + packed_modules_mapping = self._adapter_manager.packed_modules_mapping + expected_lora_lst: list[str] = [] + for module in supported_lora_modules: + if module in packed_modules_mapping: + expected_lora_lst.extend(packed_modules_mapping[module]) + else: + expected_lora_lst.append(module) + if module == "experts": + expected_lora_lst.append(module) + expected_lora_modules = set(expected_lora_lst) + lora_weights = None + + if isinstance(lora_request, LoRARequestWithCfgAndWeights): + lora_cfg = lora_request.lora_cfg + lora_weights = lora_request.lora_weights + peft_helper = PEFTHelper.from_dict(lora_cfg) + else: + lora_path = get_adapter_absolute_path(lora_request.lora_path) + + peft_helper = PEFTHelper.from_local_dir( + lora_path, + self.max_position_embeddings, + lora_request.tensorizer_config_dict, + ) + + # Validates the LoRA configuration against requirements before + # loading weights, throwing an exception if validation fails. + peft_helper.validate_legal(self.lora_config) + + # For some models like Qwen2VL, we need to use hf_to_vllm_mapper + # to ensure correct loading of lora weights. 
+ model = self._adapter_manager.model + hf_to_vllm_mapper = getattr(model, "hf_to_vllm_mapper", None) + if isinstance(lora_request, LoRARequestWithCfgAndWeights): + lora = self._lora_model_cls.from_lora_tensors( + lora_model_id=lora_request.lora_int_id, + tensors=lora_weights, + peft_helper=peft_helper, + device="cpu", + dtype=self.lora_config.lora_dtype, + embeddings=None, + target_embedding_padding=self.vocab_size + + self.lora_config.lora_extra_vocab_size, + embedding_modules=self.embedding_modules, + embedding_padding_modules=self.embedding_padding_modules, + weights_mapper=hf_to_vllm_mapper, + ) + else: + lora = self._lora_model_cls.from_local_checkpoint( + lora_path, + expected_lora_modules, + peft_helper=peft_helper, + lora_model_id=lora_request.lora_int_id, + device="cpu", + dtype=self.lora_config.lora_dtype, + target_embedding_padding=self.vocab_size + + self.lora_config.lora_extra_vocab_size, + embedding_modules=self.embedding_modules, + embedding_padding_modules=self.embedding_padding_modules, + weights_mapper=hf_to_vllm_mapper, + ) + + except FileNotFoundError as e: + # FileNotFoundError should be raised if both + # - No adapter found to download from huggingface (or in + # offline mode) + # - No local adapter files found at `lora_request.lora_path` + # For NotFoundError + raise ValueError( + f"Loading lora {lora_request.lora_name} failed: No adapter " + f"found for {lora_request.lora_path}" + ) from e + except Exception as e: + # For BadRequestError + raise e + + if lora.extra_vocab_size > self.lora_config.lora_extra_vocab_size: + raise ValueError( + f"LoRA added vocab size {lora.extra_vocab_size} is greater than lora_extra_vocab_size " + f"{self.lora_config.lora_extra_vocab_size}." + ) + return lora + + +def patched_get_supported_lora_modules(model: nn.Module) -> list[str]: + """Skip lm_head modules in the supported_lora_modules. + + In vLLM, all linear layers support LoRA. But in Automodel, lm_head not support LoRA. 
+ Refer to https://github.com/NVIDIA-NeMo/Automodel/blob/50253d14c2aefa2206036022b4ccce9f3476ba4d/nemo_automodel/components/_peft/module_matcher.py#L99 for more details. + """ + supported_lora_modules: set[str] = set() + for name, module in model.named_modules(): + # get the embedding modules if the module's embedding_modules + # is not empty. + embedding_modules = getattr(module, "embedding_modules", None) + if embedding_modules is not None: + for name in embedding_modules: + if "lm_head" in name: + continue + supported_lora_modules.add(name) + + # get all the linear subfixes. + if isinstance(module, (LinearBase,)): + supported_lora_modules.add(name.split(".")[-1]) + + if isinstance(module, (FusedMoE,)): + supported_lora_modules.add(name.split(".")[-1]) + + return list(supported_lora_modules) + + +def apply_lora_patches(): + # patch the get_supported_lora_modules function + import vllm.lora.utils as lora_utils + + setattr( + lora_utils, "get_supported_lora_modules", patched_get_supported_lora_modules + ) + + # patch the get_supported_lora_modules function in lora_models + import vllm.lora.models as lora_models + + setattr( + lora_models, "get_supported_lora_modules", patched_get_supported_lora_modules + ) + + assert vllm.__version__.startswith("0.11."), ( + "vLLM version must be == 0.11.x to apply the patches. " + "If this assertion fails, please check the vLLM version and remove the patching on condition. " + "You can:\n" + "1. Check whether vllm support load lora from memory.\n" + "2. If yes, remove the patching call\n" + "3. Delete this assertion" + "4. 
Delete this patch: patched_load_adapter" + ) + # patch the load_adapter function in LRUCacheWorkerLoRAManager + from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager + + setattr(LRUCacheWorkerLoRAManager, "_load_adapter", patched_load_adapter) + + +def apply_weight_name_mapping( + weights: list[tuple[str, torch.Tensor]], + supported_modules: list[str], + packed_modules_mapping: dict[str, list[str]], +) -> list[tuple[str, torch.Tensor]]: + """Apply weight name mapping if LoRA is enabled.""" + + def map_param_name(param_name: str) -> str: + # Vllm add logits_processor to lm_head weight(https://github.com/vllm-project/vllm/blob/b8b302cde434df8c9289a2b465406b47ebab1c2d/vllm/lora/models.py#L506), we skip mapping for lm_head weight + if "lm_head" in param_name: + return param_name + parts = param_name.split(".") + if len(parts) < 2: + return param_name + base_name = ".".join(parts[:-2]) # prefix + module_name = parts[-2] # e.g. q_proj/k_proj/v_proj/gate_proj/up_proj/... + field_name = parts[-1] # weight/bias + resolved_module_name = module_name + for packed_name, member_names in packed_modules_mapping.items(): + if module_name in member_names: + resolved_module_name = packed_name + break + # use resolved_module_name for checking, but return the original module_name + if resolved_module_name in supported_modules: + if base_name != "": + return f"{base_name}.{module_name}.base_layer.{field_name}" + else: + return f"{module_name}.base_layer.{field_name}" + return param_name + + new_weights = [] + for name, w in weights: + new_name = map_param_name(name) + new_weights.append((new_name, w)) + return new_weights diff --git a/nemo_rl/models/generation/vllm/vllm_backend.py b/nemo_rl/models/generation/vllm/vllm_backend.py index e342139d59..5e39681731 100644 --- a/nemo_rl/models/generation/vllm/vllm_backend.py +++ b/nemo_rl/models/generation/vllm/vllm_backend.py @@ -13,7 +13,7 @@ # limitations under the License. 
import gc import traceback -from typing import Any +from typing import Any, Optional import torch import zmq @@ -87,7 +87,13 @@ def maybe_init_zmq(self): self.zmq_socket.setsockopt(zmq.LINGER, 0) self.zmq_socket.connect(self.get_zmq_address()) - def prepare_refit_info(self, state_dict_info: dict[str, Any]) -> None: + def prepare_refit_info( + self, + state_dict_info: dict[str, Any], + lora_enabled: bool, + lora_metadata: Optional[dict[str, Any]] = None, + lora_cfg_dict: Optional[dict[str, Any]] = None, + ) -> None: """Prepare state dict metadata for weight refitting and IPC streaming. Args: @@ -95,6 +101,9 @@ def prepare_refit_info(self, state_dict_info: dict[str, Any]) -> None: e.g. {tensor_name: (shape, dtype)} """ self.state_dict_info = state_dict_info # pyrefly: ignore[implicitly-defined-attribute] This class does not define __init__ so assignments like this should be ignored + self.lora_enabled = lora_enabled # pyrefly: ignore[implicitly-defined-attribute] This class does not define __init__ so assignments like this should be ignored + self.lora_metadata = lora_metadata # pyrefly: ignore[implicitly-defined-attribute] This class does not define __init__ so assignments like this should be ignored + self.lora_cfg_dict = lora_cfg_dict # pyrefly: ignore[implicitly-defined-attribute] This class does not define __init__ so assignments like this should be ignored def _maybe_process_fp8_kv_cache(self) -> None: """Process weights after loading for FP8 KV cache (static scales).""" @@ -125,8 +134,82 @@ def _maybe_process_fp8_kv_cache(self) -> None: target_device, ) + def _apply_loaded_weights( + self, + weights: list[tuple[str, torch.Tensor]], + refit_mode: Optional[str] = "base_model", + ) -> None: + """Apply loaded weights to model or LoRA based on flags. + + This unifies the duplicate logic used by both IPC and collective paths. 
+ """ + assert refit_mode in ["base_model", "lora"], ( + f"refit_mode must be 'base_model' or 'lora', but got {refit_mode}" + ) + + from nemo_rl.models.generation.vllm.quantization import fp8 + + runner = self.model_runner + + if fp8.is_fp8_model(runner.vllm_config): + assert refit_mode == "base_model", ( + "fp8 model only supports base_model refit mode" + ) + # the fp8 load_weights additionally casts bf16 weights into fp8 + fp8.load_weights(weights, runner) + return + + if refit_mode == "base_model": + if self.lora_enabled: + from nemo_rl.models.generation.vllm.lora import ( + apply_weight_name_mapping, + ) + + lora_mgr = self.model_runner.model.lora_manager + supported_modules = lora_mgr.supported_lora_modules + packed_modules_mapping = lora_mgr.packed_modules_mapping + + weights = apply_weight_name_mapping( + weights, supported_modules, packed_modules_mapping + ) + runner.model.load_weights(weights=weights) + return + + if refit_mode == "lora": + from nemo_rl.models.generation.vllm.lora import LoRARequestWithCfgAndWeights + + if self.lora_metadata is None or self.lora_cfg_dict is None: + raise ValueError( + "LoRA metadata/config must be set before LoRA refit mode." 
+ ) + lora_metadata = self.lora_metadata + lora_cfg_dict = self.lora_cfg_dict + # Note: We don't need to remove the lora if it is already set max_loras = 1 + self.remove_lora(lora_id=lora_metadata["lora_int_id"]) + lora_request = LoRARequestWithCfgAndWeights( + **lora_metadata, + lora_cfg=lora_cfg_dict, + lora_weights=dict({name: tensor for name, tensor in weights}), + ) + try: + self.add_lora(lora_request=lora_request) + except Exception as e: + print( + f"Error in VllmInternalWorkerExtension._apply_loaded_weights: {e}" + ) + print(traceback.format_exc()) + raise e + # self.add_lora(lora_request=lora_request) + return + + raise ValueError( + f"refit_mode must be 'base_model' or 'lora', but got {refit_mode}" + ) + @wrap_with_nvtx_name("vllm_internal_worker_extension/update_weights_via_ipc_zmq") - def update_weights_via_ipc_zmq(self) -> bool: + def update_weights_via_ipc_zmq( + self, refit_mode: Optional[str] = "base_model" + ) -> bool: """Receive and update model weights via ZMQ IPC socket. 
Returns: @@ -176,14 +259,11 @@ def update_weights_via_ipc_zmq(self) -> bool: assert offset == used_bytes, ( "Offset is not equal to used bytes, usually indicate inaccurate info like keys or cached dtype in state_dict_info" ) - # Load weights into the model - from nemo_rl.models.generation.vllm.quantization import fp8 - - if fp8.is_fp8_model(self.model_runner.vllm_config): - # the fp8 load_weights additionally casts bf16 weights into fp8 - fp8.load_weights(weights, self.model_runner) - else: - self.model_runner.model.load_weights(weights=weights) + # Load weights into the model or LoRA + self._apply_loaded_weights( + weights=weights, + refit_mode=refit_mode, + ) torch.cuda.current_stream().synchronize() diff --git a/nemo_rl/models/generation/vllm/vllm_generation.py b/nemo_rl/models/generation/vllm/vllm_generation.py index 93540ebe82..cbe91dc443 100644 --- a/nemo_rl/models/generation/vllm/vllm_generation.py +++ b/nemo_rl/models/generation/vllm/vllm_generation.py @@ -767,7 +767,9 @@ def prepare_refit_info(self, state_dict_info: dict[str, Any]) -> None: # Wait for all futures to complete ray.get(futures) - def update_weights_via_ipc_zmq(self) -> list[ray.ObjectRef]: + def update_weights_via_ipc_zmq( + self, refit_mode: Optional[str] = "base_model" + ) -> list[ray.ObjectRef]: """Update weights of the policy using IPC handles via ZMQ socket.""" if not self.worker_group or not self.worker_group.workers: raise RuntimeError("Worker group is not initialized") @@ -783,6 +785,7 @@ def update_weights_via_ipc_zmq(self) -> list[ray.ObjectRef]: futures = self.worker_group.run_all_workers_single_data( method_name, run_rank_0_only_axes=["tensor_parallel", "pipeline_parallel"], + refit_mode=refit_mode, ) # this function should co-work with lm_policy, so we should wait for all futures to complete outside diff --git a/nemo_rl/models/generation/vllm/vllm_worker.py b/nemo_rl/models/generation/vllm/vllm_worker.py index 9238533cd2..8f14e49a27 100644 --- 
a/nemo_rl/models/generation/vllm/vllm_worker.py +++ b/nemo_rl/models/generation/vllm/vllm_worker.py @@ -139,6 +139,7 @@ def __init__( self.enable_expert_parallel = self.expert_parallel_size > 1 self.gpu_memory_utilization = self.cfg["vllm_cfg"]["gpu_memory_utilization"] self.precision = self.cfg["vllm_cfg"]["precision"] + self.lora_cfg = self.cfg["vllm_cfg"].get("lora_cfg", None) self.fraction_of_gpus = fraction_of_gpus self.is_model_owner = bundle_indices is not None @@ -397,6 +398,41 @@ def _patch_vllm_vit_flash_attn_backend(): ) self.cfg["vllm_cfg"]["skip_tokenizer_init"] = False + # Lora is enabled, add it to the vllm kwargs + self.lora_enabled = False + self.lora_metadata = None + self.lora_cfg_dict = None + if self.lora_cfg is not None and self.lora_cfg["enabled"]: + try: + from nemo_rl.models.generation.vllm.lora import apply_lora_patches + + apply_lora_patches() + + except Exception as e: + raise RuntimeError( + f"Lora is enabled, but failed to apply lora patches: {e}" + ) + + self.lora_enabled = True + self.lora_metadata = dict( + { + "lora_int_id": 1, # Can be any unique id exclude 0 + "lora_name": "lora_1", + "lora_path": "dummy_lora_path", + } + ) + self.lora_cfg_dict = dict( + { + "r": self.lora_cfg["dim"], + "lora_alpha": self.lora_cfg["alpha"], + "target_modules": self.lora_cfg["target_modules"], + } + ) + + vllm_kwargs["enable_lora"] = True + vllm_kwargs["max_loras"] = 1 # only support one lora adapter + vllm_kwargs["max_lora_rank"] = self.lora_cfg["dim"] + llm_kwargs = dict( model=self.model_name, served_model_name=self.model_name, @@ -568,7 +604,15 @@ def generate( assert self.llm is not None, ( "Attempting to generate with either an uninitialized vLLM or non-model-owner" ) - outputs = self.llm.generate(prompts, sampling_params) + + lora_req = None + if self.lora_enabled: + from vllm.lora.request import LoRARequest + + lora_req = LoRARequest( + **self.lora_metadata, + ) + outputs = self.llm.generate(prompts, sampling_params, lora_request=lora_req) 
# Process the outputs - but preserve the original input padding structure output_ids_list = [] @@ -697,7 +741,17 @@ def generate_text( assert self.llm is not None, ( "Attempting to generate with either an uninitialized vLLM or non-model-owner" ) - outputs = self.llm.generate(data["prompts"], sampling_params) + + lora_req = None + if self.lora_enabled: + from vllm.lora.request import LoRARequest + + lora_req = LoRARequest( + **self.lora_metadata, + ) + outputs = self.llm.generate( + data["prompts"], sampling_params, lora_request=lora_req + ) texts = [output.outputs[0].text for output in outputs] # Convert to BatchedDataDict @@ -724,10 +778,21 @@ def report_device_id(self) -> list[str]: def prepare_refit_info(self, state_dict_info: dict[str, Any]) -> None: """Prepare the info for refit.""" - self.llm.collective_rpc("prepare_refit_info", args=(state_dict_info,)) + self.llm.collective_rpc( + "prepare_refit_info", + args=( + state_dict_info, + self.lora_enabled, + self.lora_metadata, + self.lora_cfg_dict, + ), + ) @wrap_with_nvtx_name("vllm_genertion_worker/update_weights_via_ipc_zmq") - def update_weights_via_ipc_zmq(self) -> bool: + def update_weights_via_ipc_zmq( + self, + refit_mode: Optional[str] = "base_model", + ) -> bool: """Update weights from IPC handles via ZMQ socket.""" try: assert self.llm is not None, ( @@ -741,7 +806,7 @@ def update_weights_via_ipc_zmq(self) -> bool: result_or_coro = self.llm.collective_rpc( "update_weights_via_ipc_zmq", - args=tuple(), + args=(refit_mode,), ) worker_result = result_or_coro[0] diff --git a/nemo_rl/models/generation/vllm/vllm_worker_async.py b/nemo_rl/models/generation/vllm/vllm_worker_async.py index 0e4ea5cdeb..f7c6f490e7 100644 --- a/nemo_rl/models/generation/vllm/vllm_worker_async.py +++ b/nemo_rl/models/generation/vllm/vllm_worker_async.py @@ -984,10 +984,19 @@ async def report_device_id_async(self) -> list[str]: async def prepare_refit_info_async(self, state_dict_info: dict[str, Any]) -> None: """Async version of 
prepare_refit_info.""" - await self.llm.collective_rpc("prepare_refit_info", args=(state_dict_info,)) + await self.llm.collective_rpc( + "prepare_refit_info", + args=( + state_dict_info, + self.lora_enabled, + self.lora_metadata, + self.lora_cfg_dict, + ), + ) async def update_weights_via_ipc_zmq_async( self, + refit_mode: Optional[str] = "base_model", ) -> bool: """Async version of update_weights_via_ipc_zmq.""" try: @@ -1002,7 +1011,8 @@ async def update_weights_via_ipc_zmq_async( # TODO: switch to update_weights_from_local_ipc_handles for better performance once collectively report_device_id is supported in asyncLLM initialization result_or_coro = await self.llm.collective_rpc( - "update_weights_via_ipc_zmq", args=tuple() + "update_weights_via_ipc_zmq", + args=(refit_mode,), ) if asyncio.iscoroutine(result_or_coro): diff --git a/nemo_rl/models/policy/lm_policy.py b/nemo_rl/models/policy/lm_policy.py index 144683c95c..59b3387d59 100644 --- a/nemo_rl/models/policy/lm_policy.py +++ b/nemo_rl/models/policy/lm_policy.py @@ -86,6 +86,9 @@ def __init__( "Configure either Megatron (policy.megatron_cfg.enabled=true) or " "DTensor (policy.dtensor_cfg.enabled=true), not both." ) + # Default to False, will be overridden if LoRA is enabled + self.lora_enabled = False + if megatron_enable: worker_builder_cls = "nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker" tp_size = config["megatron_cfg"]["tensor_model_parallel_size"] @@ -109,6 +112,9 @@ def __init__( # Check if _v2 is enabled in dtensor_cfg (defaults to False for backward compatibility) use_v2 = config.get("dtensor_cfg", {}).get("_v2", False) + lora_cfg = config.get("dtensor_cfg", {}).get("lora_cfg", {}) + self.lora_enabled = lora_cfg.get("enabled", False) + if use_v2: worker_builder_cls = "nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2" @@ -118,10 +124,9 @@ def __init__( "if you are running a custom container or baremetal, you may need to set this variable manually. 
Example: export TORCH_CUDA_ARCH_LIST='9.0 10.0'" ) else: - assert ( - config["dtensor_cfg"].get("lora_cfg", {}).get("enabled", False) - is False - ), "LoRA is not supported for DTensorPolicyWorker V1" + assert not self.lora_enabled, ( + "LoRA is not supported for DTensorPolicyWorker V1" + ) worker_builder_cls = "nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker" tp_size = config["dtensor_cfg"]["tensor_parallel_size"] @@ -758,13 +763,17 @@ def get_free_memory_bytes(self) -> int: return free_memory_bytes def stream_weights_via_ipc_zmq( - self, buffer_size_bytes: int, kv_scales: Optional[dict[str, float]] = None + self, + buffer_size_bytes: int, + kv_scales: Optional[dict[str, float]] = None, + refit_mode: Optional[str] = "base_model", ) -> list[ray.ObjectRef]: """Send the weights for IPC handles via ZMQ socket.""" futures = self.worker_group.run_all_workers_single_data( "stream_weights_via_ipc_zmq", buffer_size_bytes=buffer_size_bytes, kv_scales=kv_scales, + refit_mode=refit_mode, ) return futures @@ -885,3 +894,23 @@ def print_node_ip_and_gpu_id(self) -> list[tuple[str, int]]: table.add_row(row) print(table) + + def check_lora_base_refit_done(self) -> tuple[bool, bool]: + """Return (lora_enabled, base_refit_done); base_refit_done is True once all workers have refit the base model weights.""" + dtensor_cfg = self.cfg.get("dtensor_cfg", {}) + is_dtensor = dtensor_cfg.get("enabled", False) + is_v2 = is_dtensor and dtensor_cfg.get("_v2", False) + is_megatron = self.cfg.get("megatron_cfg", {}).get("enabled", False) + + # Only DTensor v2 with LoRA supports lora GRPO workflow + lora_enabled = is_v2 and dtensor_cfg.get("lora_cfg", {}).get("enabled", False) + + if is_megatron or (is_dtensor and not is_v2) or (not lora_enabled): + return False, False + + # Check if the base model weights have been refit only when LoRA is enabled and DTensor v2 is used + futures = self.worker_group.run_all_workers_single_data( + "get_lora_base_refit_done" + ) + results = ray.get(futures) + return True, all(results) diff --git
a/nemo_rl/models/policy/workers/dtensor_policy_worker.py b/nemo_rl/models/policy/workers/dtensor_policy_worker.py index e9c57cfb55..5b12161a82 100644 --- a/nemo_rl/models/policy/workers/dtensor_policy_worker.py +++ b/nemo_rl/models/policy/workers/dtensor_policy_worker.py @@ -1680,8 +1680,13 @@ def stream_weights_via_ipc_zmq( self, buffer_size_bytes: int = 0, kv_scales: Optional[dict[str, float]] = None, + refit_mode: Optional[str] = "base_model", ) -> None: """Stream model weights to peer process via ZMQ IPC socket.""" + assert refit_mode == "base_model", ( + f"refit_mode must be 'base_model' in dtensor v1, but got {refit_mode}" + ) + if kv_scales is not None: raise NotImplementedError( "FP8 kvcache is not currently supported for DTensor path, we will support it in the future." diff --git a/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py b/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py index 49e1360c57..e1583a7c25 100644 --- a/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py +++ b/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py @@ -97,6 +97,7 @@ from nemo_rl.utils.checkpoint import CheckpointingConfig from nemo_rl.utils.nsys import wrap_with_nvtx_name from nemo_rl.utils.packed_tensor import packed_broadcast_producer +from nemo_rl.utils.weights import is_base_model_weight_name, is_lora_weight_name STRING_TO_DTYPE = { "float32": torch.float32, @@ -106,7 +107,9 @@ def dtensor_params_generator( - model: nn.Module, target_dtype: torch.dtype + model: nn.Module, + target_dtype: torch.dtype, + refit_mode: Optional[str] = "base_model", ) -> Generator[tuple[str, torch.Tensor], None, None]: """Generator that yields (name, tensor) pairs, converting DTensors to local tensors and adapting to HF format. @@ -117,7 +120,15 @@ def dtensor_params_generator( Yields: Tuples of (fully_qualified_name, tensor) where tensors are converted to target dtype and made contiguous. 
""" + assert refit_mode in ["base_model", "lora"], ( + f"refit_mode must be 'base_model' or 'lora', but got {refit_mode}" + ) for name, tensor in model.state_dict().items(): + if is_base_model_weight_name(name) and refit_mode != "base_model": + continue + if is_lora_weight_name(name) and refit_mode != "lora": + continue + full_tensor = tensor.full_tensor() if isinstance(tensor, DTensor) else tensor adapted_fqn_tensors = _maybe_adapt_tensor_to_hf(model, name, full_tensor) for adapted_fqn, adapted_tensor in adapted_fqn_tensors: @@ -329,6 +340,8 @@ def __init__( # autocast should cast the weights to the correct dtype during the forward pass. cfg_dict_with_dtype = {**lora_cfg, "lora_dtype": "torch.float32"} self.peft_config = PeftConfig.from_dict(cfg_dict_with_dtype) + # Track if the base model weights have been refit, used only for LoRA. + self.lora_base_refit_done = False print(f"[Rank {self.rank}] Initializing empty model for FSDP...") # All ranks initialize model on meta device, so FSDP can shard it. @@ -385,6 +398,16 @@ def __init__( ) if self.lora_enabled: apply_lora_to_linear_modules(self.model, self.peft_config) + if hasattr(self.model, "lm_head"): + assert not hasattr(self.model.lm_head, "lora_A"), ( + "lm_head should not be patched with LoRA adapters. " + "If this assertion fails, the upstream bug has been fixed in Automodel. " + "You can:\n" + "1. Remove the patch patched_get_supported_lora_modules in nemo_rl/models/generation/vllm/lora.py\n" + "2. Remove the patching call\n" + "3. Retest the reward in train and accuracy in validation at the first step should be exactly equal for Llama3.2-3B-Instruct model.\n" + "4. Delete this assertion" + ) # For activation checkpointing, we also must globally disable the cudnn SDPA backend # to ensure that cudnn does not get selected during recomputation. 
@@ -1809,6 +1832,7 @@ def stream_weights_via_ipc_zmq( self, buffer_size_bytes: int = 0, kv_scales: Optional[dict[str, float]] = None, + refit_mode: Optional[str] = "base_model", ) -> None: """Stream model weights to peer process via ZMQ IPC socket.""" if kv_scales is not None: @@ -1816,6 +1840,12 @@ def stream_weights_via_ipc_zmq( "FP8 kvcache is not currently supported for DTensor path, we will support it in the future." ) + if refit_mode == "base_model" and self.lora_enabled: + assert not self.lora_base_refit_done, ( + "Base model weights have already been refit, cannot refit again" + ) + self.lora_base_refit_done = True + self.maybe_init_zmq() # Manually move model to cuda for cpu offload case if self.cpu_offload: @@ -1825,7 +1855,9 @@ def stream_weights_via_ipc_zmq( # Use the shared implementation stream_weights_via_ipc_zmq_impl( - params_generator=dtensor_params_generator(self.model, self.dtype), + params_generator=dtensor_params_generator( + self.model, self.dtype, refit_mode=refit_mode + ), buffer_size_bytes=buffer_size_bytes, zmq_socket=self.zmq_socket, rank=self.rank, @@ -2026,3 +2058,7 @@ def _init_checkpoint_manager( config_updates=config_updates, checkpoint_root=checkpoint_root, ) + + def get_lora_base_refit_done(self) -> bool: + """Get if the base model weights have been refit.""" + return self.lora_base_refit_done diff --git a/nemo_rl/models/policy/workers/megatron_policy_worker.py b/nemo_rl/models/policy/workers/megatron_policy_worker.py index 1d175f35b2..687b336156 100644 --- a/nemo_rl/models/policy/workers/megatron_policy_worker.py +++ b/nemo_rl/models/policy/workers/megatron_policy_worker.py @@ -2186,7 +2186,10 @@ def _iter_params_with_optional_kv_scales( @torch.no_grad() @wrap_with_nvtx_name("megatron_policy_worker/stream_weights_via_ipc_zmq") def stream_weights_via_ipc_zmq( - self, buffer_size_bytes: int = 0, kv_scales: Optional[dict[str, float]] = None + self, + buffer_size_bytes: int = 0, + kv_scales: Optional[dict[str, float]] = None, + 
refit_mode: Optional[str] = "base_model", ) -> None: """Stream model weights to peer process via ZMQ IPC socket.""" self.maybe_init_zmq() diff --git a/nemo_rl/utils/weights.py b/nemo_rl/utils/weights.py new file mode 100644 index 0000000000..52ab2b2804 --- /dev/null +++ b/nemo_rl/utils/weights.py @@ -0,0 +1,27 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +def is_lora_weight_name(name: str) -> bool: + """Return True if a parameter name corresponds to a LoRA weight.""" + return ( + name.endswith(".lora_A.weight") + or name.endswith(".lora_B.weight") + or name.endswith(".lora_scaling.weight") + ) + + +def is_base_model_weight_name(name: str) -> bool: + """Return True if a parameter name corresponds to a base (non-LoRA) weight.""" + return not is_lora_weight_name(name) diff --git a/tests/functional/L1_Functional_Tests_GPU.sh b/tests/functional/L1_Functional_Tests_GPU.sh old mode 100644 new mode 100755 index ec7527f583..30e4a8bfbd --- a/tests/functional/L1_Functional_Tests_GPU.sh +++ b/tests/functional/L1_Functional_Tests_GPU.sh @@ -27,6 +27,7 @@ time uv run --no-sync bash ./tests/functional/sft.sh time uv run --no-sync bash ./tests/functional/sft_resume_diamond.sh time uv run --no-sync bash ./tests/functional/grpo.sh time uv run --no-sync bash ./tests/functional/grpo_async.sh +time uv run --no-sync bash ./tests/functional/grpo_automodel_lora.sh time uv run --no-sync bash 
./tests/functional/grpo_megatron.sh time uv run --no-sync bash ./tests/functional/grpo_megatron_generation.sh time uv run --no-sync bash ./tests/functional/grpo_multiturn.sh diff --git a/tests/functional/grpo_automodel_lora.sh b/tests/functional/grpo_automodel_lora.sh new file mode 100755 index 0000000000..1268e5c60f --- /dev/null +++ b/tests/functional/grpo_automodel_lora.sh @@ -0,0 +1,46 @@ +#!/bin/bash + +# clean up checkpoint directory on exit +trap "rm -rf /tmp/lora_sft_checkpoints" EXIT + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) +PROJECT_ROOT=$(realpath $SCRIPT_DIR/../..) +# Mark the current repo as safe, since wandb fetches metadata about the repo +git config --global --add safe.directory $PROJECT_ROOT + +set -eou pipefail + +EXP_NAME=$(basename $0 .sh) +EXP_DIR=$SCRIPT_DIR/$EXP_NAME +LOG_DIR=$EXP_DIR/logs +JSON_METRICS=$EXP_DIR/metrics.json +RUN_LOG=$EXP_DIR/run.log +export PYTHONPATH=${PROJECT_ROOT}:${PYTHONPATH:-} + +rm -rf $EXP_DIR $LOG_DIR +mkdir -p $EXP_DIR $LOG_DIR + +cd $PROJECT_ROOT +uv run coverage run -a --data-file=$PROJECT_ROOT/tests/.coverage --source=$PROJECT_ROOT/nemo_rl \ + $PROJECT_ROOT/examples/run_grpo_math.py\ + grpo.max_num_steps=3 \ + grpo.num_prompts_per_step=8 \ + grpo.num_generations_per_prompt=4 \ + policy.dtensor_cfg.lora_cfg.enabled=True \ + policy.dtensor_cfg.lora_cfg.dim=32 \ + policy.train_global_batch_size=32 \ + policy.train_micro_batch_size=1 \ + cluster.gpus_per_node=2 \ + logger.tensorboard_enabled=true \ + logger.log_dir=$LOG_DIR \ + logger.wandb_enabled=false \ + logger.monitor_gpus=true \ + checkpointing.enabled=false \ + "$@" \ + 2>&1 | tee $RUN_LOG + +uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS + +uv run tests/check_metrics.py $JSON_METRICS \ + 'data["train/reward"]["3"] > 0.06' + diff --git a/tests/functional/sft_automodel_lora.sh b/tests/functional/sft_automodel_lora.sh old mode 100644 new mode 100755 diff --git 
a/tests/test_suites/llm/grpo-qwen3-8B-base-1n8g-fsdp2-lora.sh b/tests/test_suites/llm/grpo-qwen3-8B-base-1n8g-fsdp2-lora.sh new file mode 100755 index 0000000000..2e2ad8beb5 --- /dev/null +++ b/tests/test_suites/llm/grpo-qwen3-8B-base-1n8g-fsdp2-lora.sh @@ -0,0 +1,43 @@ +#!/bin/bash +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) +source $SCRIPT_DIR/common.env + +# ===== BEGIN CONFIG ===== +NUM_NODES=1 +STEPS_PER_RUN=20 +MAX_STEPS=20 +NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN )) # Round up +NUM_MINUTES=30 +# ===== END CONFIG ===== + +exit_if_max_steps_reached + +# Run the experiment +cd $PROJECT_ROOT +uv run examples/run_grpo_math.py \ + --config $CONFIG_PATH \ + grpo.max_num_steps=$MAX_STEPS \ + logger.log_dir=$LOG_DIR \ + logger.wandb_enabled=True \ + logger.wandb.project=nemo-rl \ + logger.wandb.name=$EXP_NAME \ + logger.monitor_gpus=True \ + logger.tensorboard_enabled=True \ + checkpointing.enabled=True \ + checkpointing.checkpoint_dir=$CKPT_DIR \ + $@ \ + 2>&1 | tee $RUN_LOG + +# Convert tensorboard logs to json +uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS + +# Only run metrics if the target step is reached +if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | map(tonumber) | max' $JSON_METRICS) -ge $MAX_STEPS ]]; then + uv run tests/check_metrics.py $JSON_METRICS \ + 'data["train/gen_kl_error"]["20"] < 0.001' \ + 'max(data["train/reward"]) > 0.35' \ + 'mean(data["timing/train/total_step_time"], 2) < 80' + + # Clean up checkpoint directory after successful run to save space. 
+ rm -rf "$CKPT_DIR" +fi diff --git a/tests/test_suites/nightly.txt b/tests/test_suites/nightly.txt index 4c93e4fcb9..c0eba95271 100644 --- a/tests/test_suites/nightly.txt +++ b/tests/test_suites/nightly.txt @@ -62,6 +62,9 @@ tests/test_suites/llm/grpo-llama3.1-8b-instruct-2n8g-fsdp2tp1-noncolocated.sh tests/test_suites/llm/grpo-nano-v2-12b-1n8g-megatron.sh tests/test_suites/llm/grpo-nano-v2-12b-2n8g-fsdp2tp1.sh +# lora +tests/test_suites/llm/grpo-qwen3-8B-base-1n8g-fsdp2-lora.sh + ####### # SFT # ####### diff --git a/tests/unit/models/dtensor/test_lora.py b/tests/unit/models/dtensor/test_lora.py index b4c92f1120..abe8216153 100644 --- a/tests/unit/models/dtensor/test_lora.py +++ b/tests/unit/models/dtensor/test_lora.py @@ -31,6 +31,7 @@ PeftConfig, apply_lora_to_linear_modules, ) +from nemo_automodel.components._peft.module_matcher import ModuleMatcher class SimpleLoraMock(nn.Module): @@ -264,3 +265,29 @@ def test_dropout_pre_post_effects(dummy_input): assert not torch.allclose(out_pre, out_post), ( "Dropout positions should affect output differently" ) + + +def test_patched_get_supported_lora_modules_needed(): + target_modules = ["*"] + exclude_modules = [] + match_all_linear = True + is_causal_lm = True + module_matcher = ModuleMatcher( + target_modules=target_modules, + exclude_modules=exclude_modules, + match_all_linear=match_all_linear, + is_causal_lm=is_causal_lm, + ) + + model = nn.Module() + model.lm_head = nn.Linear(10, 10) + is_lm_head_supported = module_matcher.match(model.lm_head, "lm_head") + assert is_lm_head_supported is False, ( + "LoRA Adapter should not be applied to lm_head. " + "If this assertion fails, the upstream bug has been fixed in Automodel. " + "You can:\n" + "1. Remove the patch patched_get_supported_lora_modules in nemo_rl/models/generation/vllm/lora.py\n" + "2. Remove the patching call\n" + "3. Retest the reward in train and accuracy in validation at the first step should be exactly equal for Llama3.2-3B-Instruct model.\n" + "4. 
Delete this test" + ) diff --git a/tests/unit/models/generation/test_vllm_generation.py b/tests/unit/models/generation/test_vllm_generation.py index 1599b7e703..ef8aa4fa8b 100644 --- a/tests/unit/models/generation/test_vllm_generation.py +++ b/tests/unit/models/generation/test_vllm_generation.py @@ -35,7 +35,7 @@ from nemo_rl.models.generation.vllm.vllm_worker_async import ( _replace_prefix_tokens, ) -from nemo_rl.models.policy import PolicyConfig +from nemo_rl.models.policy import LoRAConfig, PolicyConfig from nemo_rl.models.policy.lm_policy import Policy model_name = "Qwen/Qwen3-0.6B" @@ -70,6 +70,7 @@ "skip_tokenizer_init": False, "load_format": "auto", "enforce_eager": "False", + "kv_cache_dtype": "auto", }, "colocated": { "enabled": True, @@ -105,6 +106,7 @@ }, }, "dtensor_cfg": { + "_v2": False, "enabled": True, "cpu_offload": False, "sequence_parallel": False, @@ -127,6 +129,19 @@ "generation": deepcopy(basic_vllm_test_config), } +basic_lora_test_config: LoRAConfig = { + "enabled": False, + "target_modules": [], + "exclude_modules": [], + "match_all_linear": True, + "dim": 8, + "alpha": 32, + "dropout": 0.0, + "dropout_position": "post", + "lora_A_init": "xavier", + "use_triton": False, +} + def get_basic_megatron_test_config( tp: int = 1, @@ -688,7 +703,13 @@ def configure_worker_fixed_seed(num_gpus, bundle_indices=None): async def run_hf_train_process( - lm_policy, vllm_policy, tokenizer, async_engine, colocated, vllm_precision + lm_policy, + vllm_policy, + tokenizer, + async_engine, + colocated, + vllm_precision, + enable_lora, ): """Validates that the two policies can work together. 
@@ -868,16 +889,19 @@ async def run_hf_train_process( @pytest.mark.timeout(300) @pytest.mark.asyncio @pytest.mark.parametrize( - ("async_engine", "cpu_offload", "vllm_precision"), + ("async_engine", "cpu_offload", "vllm_precision", "enable_lora"), [ - (True, False, "bfloat16"), - (False, True, "bfloat16"), - (True, False, "fp8"), - (False, True, "fp8"), + (True, False, "bfloat16", False), + (False, True, "bfloat16", False), + (True, False, "fp8", False), + (False, True, "fp8", False), + # LoRA tests + (False, False, "bfloat16", True), + (False, True, "bfloat16", True), ], ) async def test_vllm_generation_with_hf_training_colocated( - cluster, tokenizer, async_engine, cpu_offload, vllm_precision + cluster, tokenizer, async_engine, cpu_offload, vllm_precision, enable_lora ): """This test validates that DTensor policy can work together with colocated vLLM policy.""" @@ -894,6 +918,8 @@ async def test_vllm_generation_with_hf_training_colocated( vllm_config = deepcopy(basic_vllm_test_config) vllm_config["vllm_cfg"]["async_engine"] = async_engine vllm_config["vllm_cfg"]["precision"] = vllm_precision + vllm_config["vllm_cfg"]["lora_cfg"] = deepcopy(basic_lora_test_config) + vllm_config["vllm_cfg"]["lora_cfg"]["enabled"] = enable_lora vllm_config = configure_generation_config(vllm_config, tokenizer) vllm_policy = VllmGeneration(cluster, vllm_config) @@ -903,6 +929,9 @@ async def test_vllm_generation_with_hf_training_colocated( print("Creating DTensor policy...") dtensor_config = deepcopy(basic_dtensor_test_config) dtensor_config["dtensor_cfg"]["cpu_offload"] = cpu_offload + dtensor_config["dtensor_cfg"]["_v2"] = enable_lora + dtensor_config["dtensor_cfg"]["lora_cfg"] = deepcopy(basic_lora_test_config) + dtensor_config["dtensor_cfg"]["lora_cfg"]["enabled"] = enable_lora dtensor_config["train_global_batch_size"] = 4 lm_policy = Policy(cluster, dtensor_config, tokenizer) @@ -913,23 +942,34 @@ async def test_vllm_generation_with_hf_training_colocated( # Test await 
run_hf_train_process( - lm_policy, vllm_policy, tokenizer, async_engine, True, vllm_precision + lm_policy, + vllm_policy, + tokenizer, + async_engine, + True, + vllm_precision, + enable_lora, ) @pytest.mark.timeout(300) @pytest.mark.asyncio @pytest.mark.parametrize( - ("async_engine", "cpu_offload", "vllm_precision"), + ("async_engine", "cpu_offload", "vllm_precision", "enable_lora"), [ - (True, False, "bfloat16"), - (False, True, "bfloat16"), - (True, False, "fp8"), - (False, True, "fp8"), + (True, False, "bfloat16", False), + (False, True, "bfloat16", False), + (True, False, "fp8", False), + (False, True, "fp8", False), ], ) async def test_vllm_generation_with_hf_training_non_colocated( - policy_cluster_separate, tokenizer, async_engine, cpu_offload, vllm_precision + policy_cluster_separate, + tokenizer, + async_engine, + cpu_offload, + vllm_precision, + enable_lora, ): # Skip the fp8 tests if the GPU is not H100 or newer (compute capability < 9.0) if vllm_precision == "fp8": @@ -945,19 +985,30 @@ async def test_vllm_generation_with_hf_training_non_colocated( # Create VllmGeneration Policy print("Creating vLLM policy...") vllm_config = deepcopy(basic_vllm_test_config) + vllm_config["vllm_cfg"]["lora_cfg"] = deepcopy(basic_lora_test_config) vllm_config["vllm_cfg"]["async_engine"] = async_engine vllm_config["vllm_cfg"]["precision"] = vllm_precision + vllm_config["vllm_cfg"]["lora_cfg"]["enabled"] = enable_lora vllm_config["colocated"]["enabled"] = False + if vllm_precision == "fp8": + vllm_config["vllm_cfg"]["kv_cache_dtype"] = "fp8" vllm_config = configure_generation_config(vllm_config, tokenizer) vllm_policy = VllmGeneration(generation_cluster_separate, vllm_config) vllm_policy.finish_generation() + assert not (enable_lora and vllm_precision == "fp8"), ( + "LoRA is not supported with FP8" + ) # Create Policy print("Creating DTensor policy...") dtensor_config = deepcopy(basic_dtensor_test_config) dtensor_config["generation"]["colocated"]["enabled"] = False 
dtensor_config["dtensor_cfg"]["cpu_offload"] = cpu_offload dtensor_config["train_global_batch_size"] = 4 + # lora must use dtensor v2 + dtensor_config["dtensor_cfg"]["_v2"] = enable_lora + dtensor_config["dtensor_cfg"]["lora_cfg"] = deepcopy(basic_lora_test_config) + dtensor_config["dtensor_cfg"]["lora_cfg"]["enabled"] = enable_lora lm_policy = Policy(policy_cluster_separate, dtensor_config, tokenizer) # Refit @@ -980,7 +1031,13 @@ async def test_vllm_generation_with_hf_training_non_colocated( # Test await run_hf_train_process( - lm_policy, vllm_policy, tokenizer, async_engine, False, vllm_precision + lm_policy, + vllm_policy, + tokenizer, + async_engine, + False, + vllm_precision, + enable_lora, )