diff --git a/unsloth_zoo/loss_utils.py b/unsloth_zoo/loss_utils.py index b909ccbb5..edaf1f606 100644 --- a/unsloth_zoo/loss_utils.py +++ b/unsloth_zoo/loss_utils.py @@ -154,6 +154,7 @@ def post_patch_loss_function(model): pass +torch_cuda_device = torch.cuda.device def fused_linear_cross_entropy( hidden_states : torch.Tensor, lm_weight : torch.Tensor, @@ -167,16 +168,17 @@ def fused_linear_cross_entropy( # All Unsloth Zoo code licensed under LGPLv3 reduction = "sum" if num_items_in_batch is not None else "mean" if logit_softcapping == 0: logit_softcapping = None - loss = linear_cross_entropy( - hidden_states.to(lm_weight.dtype), - lm_weight, - targets = labels, - ignore_index = ignore_index, - softcap = logit_softcapping, - reduction = reduction, - shift = True, - filter_eps = accuracy_threshold, - ) + with torch_cuda_device(lm_weight.device): + loss = linear_cross_entropy( + hidden_states.to(lm_weight.dtype), + lm_weight, + targets = labels, + ignore_index = ignore_index, + softcap = logit_softcapping, + reduction = reduction, + shift = True, + filter_eps = accuracy_threshold, + ) if num_items_in_batch is not None: loss = loss / num_items_in_batch return loss pass diff --git a/unsloth_zoo/vllm_rlhf_utils.py b/unsloth_zoo/vllm_rlhf_utils.py new file mode 100644 index 000000000..92dbf67e0 --- /dev/null +++ b/unsloth_zoo/vllm_rlhf_utils.py @@ -0,0 +1,150 @@ +# Unsloth Zoo - Utilities for Unsloth +# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this program. If not, see . + +import torch +__all__ = [ + "WorkerExtension", + "ColocateWorkerExtension", +] + +def stateless_init_process_group(master_address, master_port, rank, world_size, + device): + """ + vLLM provides `StatelessProcessGroup` to create a process group + without considering the global process group in torch.distributed. + It is recommended to create `StatelessProcessGroup`, and then initialize + the data-plane communication (NCCL) between external (train processes) + and vLLM workers. + """ + from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator + from vllm.distributed.utils import StatelessProcessGroup + pg = StatelessProcessGroup.create(host=master_address, + port=master_port, + rank=rank, + world_size=world_size) + pynccl = PyNcclCommunicator(pg, device=device) + return pynccl + + +class WorkerExtension: + """ + The class for vLLM's worker to inherit from. + By defining an extension class, the code can work no matter what is + the underlying worker class. This way, the code can be compatible + with both vLLM V0 and V1. + NOTE: we define this class in a separate module, and the main module + should pass the full qualified name as `worker_extension_cls` argument. + """ + + def init_weight_update_group(self, master_address, master_port, + rank_offset, world_size): + from vllm.distributed.parallel_state import get_world_group + rank = get_world_group().rank + rank_offset + self.model_update_group = stateless_init_process_group( + master_address, + master_port, + rank, + world_size, + self.device, + ) + + def update_weight(self, name, dtype, shape): + weight = torch.empty(shape, dtype=dtype, device="cuda") + self.model_update_group.broadcast(weight, + src=0, + stream=torch.cuda.current_stream()) + + self.model_runner.model.load_weights(weights=[(name, weight)]) + + del weight + + def check_weights_changed(self): + """ + Check if the weights are updated to 0. + """ + weights_updated = True + for name, p in self.model_runner.model.named_parameters(): + weights_updated = weights_updated and torch.allclose( + p, torch.zeros_like(p)) + return weights_updated + + +class ColocateWorkerExtension: + """ + The class for vLLM's worker to inherit from, in the colocate setting. + By defining an extension class, the code can work no matter what is + the underlying worker class. This way, the code can be compatible + with both vLLM V0 and V1. + NOTE: we define this class in a separate module, and the main module + should pass the full qualified name as `worker_extension_cls` argument. + """ + + def report_device_id(self) -> str: + from vllm.platforms import current_platform + self.device_uuid = current_platform.get_device_uuid(self.device.index) + return self.device_uuid + + def update_weights_from_ipc_handles(self, ipc_handles): + handles = ipc_handles[self.device_uuid] + device_id = self.device.index + weights = [] + for name, handle in handles.items(): + func, args = handle + list_args = list(args) + # the key is to change device id to the current device id + # in case two processes have different CUDA_VISIBLE_DEVICES + list_args[6] = device_id + tensor = func(*list_args) + weights.append((name, tensor)) + self.model_runner.model.load_weights(weights=weights) + torch.cuda.synchronize() + + def get_model_runner(self): + vllm_model = self.model_runner.model + model_loras_A, model_loras_B = [], [] + vllm_loras_A, vllm_loras_B = [], [] + parameters = [] + for v_layer in vllm_model.model.layers: + print(v_layer.self_attn.qkv_proj.lora_a_stacked[0]) + vllm_loras_A .append(v_layer.self_attn.qkv_proj.lora_a_stacked[0]) + vllm_loras_A .append(v_layer.self_attn.qkv_proj.lora_a_stacked[1]) + vllm_loras_A .append(v_layer.self_attn.qkv_proj.lora_a_stacked[2]) + + # parameters.append((name, param)) + torch.cuda.synchronize() + return vllm_loras_A + + def check_weights_changed(self): + """ + Check if the weights are updated to 0. + """ + weights_updated = True + for name, p in self.model_runner.model.named_parameters(): + weights_updated = weights_updated and torch.allclose( + p, torch.zeros_like(p)) + return weights_updated + + def get_weight_ipc_handles(self): + from torch.multiprocessing.reductions import reduce_tensor + data = {} + vllm_model = self.model_runner.model + for name, p in vllm_model.named_parameters(): + # the training actor might only have a subset of the weights + # and need to all-gather the weights from all the actors. + # for demonstration, here we assume all training actors have + # the full weights. + data[name] = reduce_tensor(p.detach()) + return {self.device_uuid: data} \ No newline at end of file diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index 6cd3bdc92..f1856f260 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -394,8 +394,11 @@ def unpatch_bitsandbytes_compute_dtype(): def patch_vllm(): - # patch_bitsandbytes_quant_state() - # patch_vllm_bitsandbytes() + # Temporary patch to disable multiprocessing for vLLM + # Allows accessing model_executor + os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0" + patch_bitsandbytes_quant_state() + patch_vllm_bitsandbytes() patch_vllm_lora_tokenizer() patch_vllm_lora_load_tensors() global LORA_REQUEST_ID @@ -442,12 +445,28 @@ def vllm_dynamic_quant_supported( def get_vllm_state_dict(llm, return_state_dict = False, config = None): # All Unsloth Zoo code licensed under LGPLv3 # Unmerges vLLM modules and returns HF equivalent state_dict + # vllm_state_dict = {} try: llm_engine = getattr(llm, "llm_engine", getattr(llm, "engine", llm)) vllm_internals = llm_engine.model_executor.driver_worker.model_runner.model + + # for name, p in vllm_internals.named_parameters(): + # vllm_state_dict[name] = p except: - raise RuntimeError("Unsloth: Failed to access llm.llm_engine.model_executor.driver_worker.model_runner.model") + # Using a new VLLM version must use collective_rpc + try: + vllm_state_dict = {} + gpu_ids = llm.collective_rpc("report_device_id", args = tuple()) + weights = llm.collective_rpc("get_weight_ipc_handles", args = tuple())[0] + weights = weights[gpu_ids[0]] + for weight_name, (to_cuda_fx, cuda_data,) in weights.items(): + vllm_state_dict[weight_name] = to_cuda_fx(*cuda_data) + pass + raise NotImplementedError("Unsloth: Currently vLLM RPC is not yet fully enabled!") + except Exception as e: + raise RuntimeError(f"Unsloth: Cannot get internal vLLM states with error = {str(e)}") pass + assert(config is not None) vocab_size = config.vocab_size @@ -516,15 +535,22 @@ def get_state_dict(prefix, kk, state_dict, proj): proj = vllm_internals.model.layers[kk].mlp.down_proj get_state_dict(f"model.layers.{kk}.mlp.down_proj", 0, state_dict, proj) - state_dict[f"model.layers.{kk}.input_layernorm.weight"] = \ - vllm_internals.model.layers[kk].input_layernorm.state_dict()["weight"] - quant_state_dict[f"model.layers.{kk}.input_layernorm.weight"] = \ - state_dict[f"model.layers.{kk}.input_layernorm.weight"] - - state_dict[f"model.layers.{kk}.post_attention_layernorm.weight"] = \ - vllm_internals.model.layers[kk].post_attention_layernorm.state_dict()["weight"] - quant_state_dict[f"model.layers.{kk}.post_attention_layernorm.weight"] = \ - state_dict[f"model.layers.{kk}.post_attention_layernorm.weight"] + for layernorm_name in [ + f"model.layers.{kk}.input_layernorm", + f"model.layers.{kk}.post_attention_layernorm", + f"model.layers.{kk}.pre_feedforward_layernorm", # Gemma3 + f"model.layers.{kk}.post_feedforward_layernorm", # Gemma3 + f"model.layers.{kk}.self_attn.q_norm", # Qwen3, Gemma3 + f"model.layers.{kk}.self_attn.k_norm", # Qwen3, Gemma3 + ]: + vllm_name = layernorm_name.replace(f".{kk}.", f"[{kk}].") + vllm_name = f"vllm_internals.{vllm_name}" + try: + layernorm = eval(vllm_name).state_dict()["weight"] + state_dict[layernorm_name + ".weight"] = layernorm + except: + print(f"vllm_internals.{layernorm_name}") + pass pass # Norm @@ -1064,6 +1090,8 @@ def load_vllm( enforce_eager = enforce_eager, swap_space = swap_space, # Low memory devices like Colab (13GB) default 4GB device = device, + # New vLLM versions need to pass this in! + # worker_extension_cls = "unsloth_zoo.vllm_rlhf_utils.ColocateWorkerExtension", ) good_keys = inspect.signature(AsyncEngineArgs if use_async else EngineArgs).parameters.keys() old_keys = engine_args.keys()