Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
746 commits
Select commit Hold shift + click to select a range
be53fda
Update vision_utils.py
danielhanchen Mar 14, 2025
28f4df4
Update vision_utils.py
danielhanchen Mar 14, 2025
ad13d0a
train on completions VLMs
danielhanchen Mar 14, 2025
370cbd7
Update dataset_utils.py
danielhanchen Mar 14, 2025
bd60d26
Update dataset_utils.py
danielhanchen Mar 14, 2025
29ed559
Update dataset_utils.py
danielhanchen Mar 14, 2025
e0a4416
Update dataset_utils.py
danielhanchen Mar 14, 2025
d6e55ca
VLM train only on completions
danielhanchen Mar 14, 2025
adf8307
Update loss_utils.py
danielhanchen Mar 14, 2025
98d5885
Update dataset_utils.py
danielhanchen Mar 14, 2025
967c2ba
Update compiler.py
danielhanchen Mar 14, 2025
ddf2b8e
Update compiler.py
danielhanchen Mar 14, 2025
cb2f6c7
Update compiler.py
danielhanchen Mar 14, 2025
873c514
Update compiler.py
danielhanchen Mar 14, 2025
ca0b499
Update compiler.py
danielhanchen Mar 14, 2025
4908a16
Update compiler.py
danielhanchen Mar 14, 2025
1d4b5d7
Update compiler.py
danielhanchen Mar 14, 2025
81b45c6
Update saving_utils.py
danielhanchen Mar 14, 2025
261ffd2
Update llama_cpp.py
danielhanchen Mar 14, 2025
2ed281a
Update llama_cpp.py
danielhanchen Mar 14, 2025
d89a8fa
Update saving_utils.py
danielhanchen Mar 14, 2025
106736a
Update saving_utils.py
danielhanchen Mar 14, 2025
4abfdcd
Update __init__.py
danielhanchen Mar 14, 2025
0ac4464
Update compiler.py
danielhanchen Mar 14, 2025
e2fbe79
Update loss_utils.py
danielhanchen Mar 14, 2025
82665d4
Update compiler.py
danielhanchen Mar 14, 2025
9b6142e
Update loss_utils.py
danielhanchen Mar 14, 2025
ee92817
Update loss_utils.py
danielhanchen Mar 14, 2025
9b7600d
Update llama_cpp.py
danielhanchen Mar 14, 2025
d5b6d1c
Update loss_utils.py
danielhanchen Mar 14, 2025
86516ad
Update compiler.py
danielhanchen Mar 14, 2025
29553e4
Update llama_cpp.py
danielhanchen Mar 14, 2025
33e6c8e
Update compiler.py
danielhanchen Mar 14, 2025
5202605
Update vllm_utils.py
danielhanchen Mar 14, 2025
ca52896
Update rl_replacements.py
danielhanchen Mar 14, 2025
7baa442
Update rl_replacements.py
danielhanchen Mar 14, 2025
7ff5a1a
Update rl_replacements.py
danielhanchen Mar 14, 2025
e80aa10
Update rl_replacements.py
danielhanchen Mar 14, 2025
e93d93f
Update rl_replacements.py
danielhanchen Mar 14, 2025
9a6c231
Update rl_replacements.py
danielhanchen Mar 14, 2025
c8abd45
Update training_utils.py
danielhanchen Mar 14, 2025
1633c78
Merge branch 'main' into nightly
danielhanchen Mar 15, 2025
964129b
Update dataset_utils.py
danielhanchen Mar 15, 2025
3b690ad
Update dataset_utils.py
danielhanchen Mar 16, 2025
7bb4a13
Revert "Update dataset_utils.py"
danielhanchen Mar 16, 2025
947c5e9
Update temporary_patches.py
danielhanchen Mar 16, 2025
2fe9c6c
Update temporary_patches.py
danielhanchen Mar 16, 2025
0b2dc97
Update temporary_patches.py
danielhanchen Mar 16, 2025
b9a96dc
Update temporary_patches.py
danielhanchen Mar 16, 2025
0784a07
Update temporary_patches.py
danielhanchen Mar 16, 2025
80c2dc8
Update temporary_patches.py
danielhanchen Mar 16, 2025
26c817d
Update compiler.py
danielhanchen Mar 16, 2025
d3cdd17
Update compiler.py
danielhanchen Mar 16, 2025
31e778a
Remove prints
danielhanchen Mar 16, 2025
2c6a3c5
Update compiler.py
danielhanchen Mar 16, 2025
f3f3c9c
Update saving_utils.py
danielhanchen Mar 16, 2025
93b6a88
Update temporary_patches.py
danielhanchen Mar 16, 2025
86aee5c
Update __init__.py
danielhanchen Mar 16, 2025
ac38bff
Update pyproject.toml
danielhanchen Mar 16, 2025
f64e153
Update vllm_utils.py
danielhanchen Mar 16, 2025
4c72e79
bug fix #2008 unsloth issue - load_in_4bit = True + fast_inference = …
void-mckenzie Mar 16, 2025
1974798
Update dataset_utils.py
danielhanchen Mar 16, 2025
4df4417
Merge branch 'nightly' of https://github.com/unslothai/unsloth-zoo in…
danielhanchen Mar 16, 2025
a5c20e1
Update compiler.py
danielhanchen Mar 17, 2025
a434d45
Update temporary_patches.py
danielhanchen Mar 17, 2025
3cfb98f
Gemma 3 fixes
danielhanchen Mar 17, 2025
fc5f1c0
Update temporary_patches.py
danielhanchen Mar 17, 2025
b317e90
Update compiler.py
danielhanchen Mar 17, 2025
4121dd0
Update compiler.py
danielhanchen Mar 17, 2025
c59dcde
Gemma 3 fixes
danielhanchen Mar 17, 2025
d98ae2e
Update patching_utils.py
danielhanchen Mar 17, 2025
3073ea3
Update compiler.py
danielhanchen Mar 17, 2025
57ff5f6
Update compiler.py
danielhanchen Mar 17, 2025
c7e803b
Update patching_utils.py
danielhanchen Mar 17, 2025
3daaf0d
Update temporary_patches.py
danielhanchen Mar 17, 2025
b619b58
Update compiler.py
danielhanchen Mar 17, 2025
4e78082
Update compiler.py
danielhanchen Mar 17, 2025
c8ba677
Update temporary_patches.py
danielhanchen Mar 17, 2025
fb68ecc
Update temporary_patches.py
danielhanchen Mar 17, 2025
e5a73fe
Update temporary_patches.py
danielhanchen Mar 17, 2025
d7bbe30
Update temporary_patches.py
danielhanchen Mar 17, 2025
5f99275
Update temporary_patches.py
danielhanchen Mar 17, 2025
346812f
Update temporary_patches.py
danielhanchen Mar 17, 2025
b907d0c
Update temporary_patches.py
danielhanchen Mar 17, 2025
789171c
Update temporary_patches.py
danielhanchen Mar 17, 2025
4e2c94a
Update compiler.py
danielhanchen Mar 17, 2025
4740c99
Update compiler.py
danielhanchen Mar 17, 2025
4658d94
Update compiler.py
danielhanchen Mar 17, 2025
f9de6e9
Update compiler.py
danielhanchen Mar 17, 2025
dbdbc63
Update compiler.py
danielhanchen Mar 17, 2025
55b1963
Update compiler.py
danielhanchen Mar 17, 2025
e997ee1
Update compiler.py
danielhanchen Mar 17, 2025
0ba033f
Update compiler.py
danielhanchen Mar 17, 2025
bf821ba
Update compiler.py
danielhanchen Mar 17, 2025
d8c6e59
Update compiler.py
danielhanchen Mar 17, 2025
9967ce3
Update compiler.py
danielhanchen Mar 17, 2025
7b0c535
Update compiler.py
danielhanchen Mar 17, 2025
e6859ce
Update compiler.py
danielhanchen Mar 17, 2025
b2a8f47
Update compiler.py
danielhanchen Mar 17, 2025
ca79c93
Update compiler.py
danielhanchen Mar 17, 2025
3f67ed6
Update compiler.py
danielhanchen Mar 17, 2025
e5fb044
Update compiler.py
danielhanchen Mar 18, 2025
4a1bf2f
Update compiler.py
danielhanchen Mar 18, 2025
36ec4ee
Update compiler.py
danielhanchen Mar 18, 2025
7d1dc81
compiler
danielhanchen Mar 18, 2025
16d6137
Update gradient_checkpointing.py
danielhanchen Mar 18, 2025
9b78566
Update temporary_patches.py
danielhanchen Mar 18, 2025
e0edefe
Update temporary_patches.py
danielhanchen Mar 18, 2025
719e379
Update temporary_patches.py
danielhanchen Mar 18, 2025
8beb2b7
Update temporary_patches.py
danielhanchen Mar 18, 2025
f9cf701
Update temporary_patches.py
danielhanchen Mar 18, 2025
aa8848c
Update temporary_patches.py
danielhanchen Mar 18, 2025
ee940a9
Update temporary_patches.py
danielhanchen Mar 18, 2025
d086158
Update temporary_patches.py
danielhanchen Mar 18, 2025
5a43de2
Update temporary_patches.py
danielhanchen Mar 18, 2025
1f6589b
Update temporary_patches.py
danielhanchen Mar 18, 2025
9b904a9
Update temporary_patches.py
danielhanchen Mar 18, 2025
3c0504b
Update temporary_patches.py
danielhanchen Mar 18, 2025
417161e
Update temporary_patches.py
danielhanchen Mar 18, 2025
3f024b6
Update temporary_patches.py
danielhanchen Mar 18, 2025
b0bd2f4
Update temporary_patches.py
danielhanchen Mar 18, 2025
640e071
Update temporary_patches.py
danielhanchen Mar 18, 2025
05c2232
Update temporary_patches.py
danielhanchen Mar 18, 2025
593eecb
Update temporary_patches.py
danielhanchen Mar 18, 2025
e9c935f
Update temporary_patches.py
danielhanchen Mar 18, 2025
b71160c
causal mask dtype
danielhanchen Mar 18, 2025
a6fedb6
Fix checkpoint and save from local file (#74)
Erland366 Mar 18, 2025
c566b02
Update patching_utils.py
danielhanchen Mar 18, 2025
aaf5feb
Merge branch 'nightly' of https://github.com/unslothai/unsloth-zoo in…
danielhanchen Mar 18, 2025
94f5f4f
Update patching_utils.py
danielhanchen Mar 18, 2025
26c67cf
Update temporary_patches.py
danielhanchen Mar 18, 2025
d92bab6
Update temporary_patches.py
danielhanchen Mar 18, 2025
b04cf4b
Update compiler.py
danielhanchen Mar 18, 2025
4565db3
Update loss_utils.py
danielhanchen Mar 18, 2025
e368810
Update compiler.py
danielhanchen Mar 18, 2025
ce07e0f
Update vllm_utils.py
danielhanchen Mar 18, 2025
114150d
Update compiler.py
danielhanchen Mar 18, 2025
6bd69f1
Update peft_utils.py
danielhanchen Mar 18, 2025
9cee216
Update rl_replacements.py
danielhanchen Mar 18, 2025
df8ac03
Update vllm_utils.py
danielhanchen Mar 18, 2025
e5a321f
Update temporary_patches.py
danielhanchen Mar 18, 2025
134857d
Update temporary_patches.py
danielhanchen Mar 18, 2025
dec6433
Update temporary_patches.py
danielhanchen Mar 18, 2025
b14149b
Update temporary_patches.py
danielhanchen Mar 18, 2025
07f7dde
Update temporary_patches.py
danielhanchen Mar 18, 2025
7600d35
Update temporary_patches.py
danielhanchen Mar 18, 2025
679edeb
Update temporary_patches.py
danielhanchen Mar 18, 2025
5fd25ec
Update temporary_patches.py
danielhanchen Mar 18, 2025
a884b3c
Update temporary_patches.py
danielhanchen Mar 18, 2025
b6ab8bd
Update temporary_patches.py
danielhanchen Mar 18, 2025
cc3ca48
Update temporary_patches.py
danielhanchen Mar 18, 2025
9f5b67d
Update temporary_patches.py
danielhanchen Mar 18, 2025
e4980b2
Update temporary_patches.py
danielhanchen Mar 18, 2025
d745fb7
Update temporary_patches.py
danielhanchen Mar 18, 2025
201c1ab
Merge branch 'main' into nightly
danielhanchen Mar 18, 2025
2fb83f0
Update compiler.py
danielhanchen Mar 18, 2025
3551715
Update vllm_lora_worker_manager.py
danielhanchen Mar 19, 2025
ab47b77
Update utils.py
danielhanchen Mar 19, 2025
ceed6ab
Update temporary_patches.py
danielhanchen Mar 19, 2025
b5611c2
Update temporary_patches.py
danielhanchen Mar 19, 2025
480aaf7
Update temporary_patches.py
danielhanchen Mar 19, 2025
637c7ad
Update temporary_patches.py
danielhanchen Mar 19, 2025
5a224bb
Update temporary_patches.py
danielhanchen Mar 19, 2025
2248156
Update temporary_patches.py
danielhanchen Mar 19, 2025
ee6ed2b
Update temporary_patches.py
danielhanchen Mar 19, 2025
6d10b9b
Update temporary_patches.py
danielhanchen Mar 19, 2025
9d431b0
Update temporary_patches.py
danielhanchen Mar 19, 2025
42491ca
Update temporary_patches.py
danielhanchen Mar 19, 2025
d3ddadf
Update vllm_utils.py
danielhanchen Mar 19, 2025
dbc6a43
Update vllm_utils.py
danielhanchen Mar 19, 2025
0c4b0d2
Update vllm_utils.py
danielhanchen Mar 19, 2025
5504033
Update vllm_utils.py
danielhanchen Mar 19, 2025
2a84e79
Update dataset_utils.py
danielhanchen Mar 19, 2025
cbbc4a3
bidirectional attention
danielhanchen Mar 19, 2025
3bf532d
Update vllm_utils.py
danielhanchen Mar 19, 2025
8e687b5
Update __init__.py
danielhanchen Mar 19, 2025
a723520
Update temporary_patches.py
danielhanchen Mar 19, 2025
9d1dd42
Update temporary_patches.py
danielhanchen Mar 19, 2025
aec2701
Update temporary_patches.py
danielhanchen Mar 19, 2025
23a3a59
Update vllm_utils.py
danielhanchen Mar 19, 2025
2874477
Update vllm_utils.py
danielhanchen Mar 19, 2025
7d40491
Update vllm_utils.py
danielhanchen Mar 19, 2025
2275642
Update vllm_utils.py
danielhanchen Mar 19, 2025
9cd348f
Update vllm_utils.py
danielhanchen Mar 19, 2025
6e33fa9
Update vllm_utils.py
danielhanchen Mar 19, 2025
7ad0f55
Update vllm_lora_worker_manager.py
danielhanchen Mar 19, 2025
7fd23a0
Update vllm_lora_worker_manager.py
danielhanchen Mar 19, 2025
9176758
Update vllm_lora_worker_manager.py
danielhanchen Mar 19, 2025
b5a38b0
Merge branch 'main' into nightly
danielhanchen Mar 19, 2025
446787d
Merge branch 'main' into nightly
danielhanchen Mar 19, 2025
d2bdd9b
Update temporary_patches.py
danielhanchen Mar 19, 2025
83bde7d
Update temporary_patches.py
danielhanchen Mar 19, 2025
0fe9eaa
Update temporary_patches.py
danielhanchen Mar 19, 2025
3d70a80
Update temporary_patches.py
danielhanchen Mar 19, 2025
6b6587d
Merge branch 'main' into nightly
danielhanchen Mar 21, 2025
88301c5
Update loss_utils.py
danielhanchen Mar 21, 2025
debc0e8
Update loss_utils.py
danielhanchen Mar 21, 2025
7dc2e9d
Update loss_utils.py
danielhanchen Mar 21, 2025
57b4973
Update loss_utils.py
danielhanchen Mar 21, 2025
3cfa271
Update loss_utils.py
danielhanchen Mar 21, 2025
1f5b6f2
Update __init__.py
danielhanchen Mar 21, 2025
2f3c87b
fix: AsyncLLMEngine bugs (#82)
bradhilton Mar 22, 2025
64dd76c
fixed a typo in L119, removing unnecessary len() (#84)
SpaceHunterInf Mar 22, 2025
5a1a2b5
Merge branch 'main' into nightly
danielhanchen Mar 22, 2025
a62e4c6
Fix gradient checkpointing warning filter implementation
rolandtannous Mar 24, 2025
d115cea
Input grads fix for gemma3 (#96)
mmathew23 Mar 25, 2025
454757c
Merge pull request #97 from rolandtannous/fix/suppress-gradient-check…
shimmyshimmer Mar 25, 2025
c50123a
Update vision_utils.py
danielhanchen Mar 26, 2025
b199491
Vision requires grad
danielhanchen Mar 26, 2025
1670fa6
Check SDPA for Mistral / Pixtral
danielhanchen Mar 26, 2025
e32f797
Update compiler.py
danielhanchen Mar 26, 2025
b9d9cc5
Update vision_utils.py
danielhanchen Mar 26, 2025
5e3c88f
Update vision_utils.py
danielhanchen Mar 26, 2025
5c4086c
Update vision_utils.py
danielhanchen Mar 26, 2025
0599242
Update __init__.py
danielhanchen Mar 26, 2025
8da5939
Update vision_utils.py
danielhanchen Mar 26, 2025
db90dca
Update vision_utils.py
danielhanchen Mar 26, 2025
51cefe5
Update vision_utils.py
danielhanchen Mar 26, 2025
20b42ce
Update vision_utils.py
danielhanchen Mar 26, 2025
b03ded6
Update vision_utils.py
danielhanchen Mar 26, 2025
65469c2
Update vision_utils.py
danielhanchen Mar 26, 2025
7f4eb00
Update vision_utils.py
danielhanchen Mar 26, 2025
8584b5d
Update vision_utils.py
danielhanchen Mar 26, 2025
8bb6b55
Update vision_utils.py
danielhanchen Mar 26, 2025
20a61b0
Update vision_utils.py
danielhanchen Mar 26, 2025
86ca6d5
Update vision_utils.py
danielhanchen Mar 26, 2025
0221094
Update vision_utils.py
danielhanchen Mar 26, 2025
d13ebf7
Update vision_utils.py
danielhanchen Mar 26, 2025
f6c4b2e
Update vision_utils.py
danielhanchen Mar 26, 2025
2d1e506
Update vllm_utils.py (#99)
5k5000 Mar 26, 2025
23e018f
Update vision_utils.py
danielhanchen Mar 26, 2025
9f1eaa2
Fixes to support IterableDataset (#98)
marcandrelarochelle Mar 26, 2025
affb9d8
Merge branch 'nightly' of https://github.com/unslothai/unsloth-zoo in…
danielhanchen Mar 26, 2025
1efc541
Merge branch 'main' into nightly
danielhanchen Mar 26, 2025
6ae6d0e
Merge branch 'main' into nightly
danielhanchen May 6, 2025
8986b95
Update vllm_utils.py
danielhanchen May 6, 2025
d37ed39
Create vllm_rlhf_utils.py
danielhanchen May 6, 2025
abf388b
Update vllm_rlhf_utils.py
danielhanchen May 6, 2025
6ff1836
Update vllm_rlhf_utils.py
danielhanchen May 6, 2025
deea45f
Update vllm_rlhf_utils.py
danielhanchen May 6, 2025
4ae4dd6
Update vllm_rlhf_utils.py
danielhanchen May 7, 2025
8dcff39
Update vllm_rlhf_utils.py
danielhanchen May 7, 2025
6edc24c
Update vllm_rlhf_utils.py
danielhanchen May 7, 2025
e61ad9b
Update vllm_rlhf_utils.py
danielhanchen May 7, 2025
f74ba8e
Update vllm_rlhf_utils.py
danielhanchen May 8, 2025
ee5bb55
Update vllm_rlhf_utils.py
danielhanchen May 8, 2025
994533c
vLLM for Qwen 3
danielhanchen May 11, 2025
23d8aa7
Merge branch 'main' into nightly
danielhanchen May 11, 2025
569c3ae
Update vllm_utils.py
danielhanchen May 11, 2025
847ad07
Update vllm_utils.py
danielhanchen May 11, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 12 additions & 10 deletions unsloth_zoo/loss_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ def post_patch_loss_function(model):
pass


torch_cuda_device = torch.cuda.device
def fused_linear_cross_entropy(
hidden_states : torch.Tensor,
lm_weight : torch.Tensor,
Expand All @@ -167,16 +168,17 @@ def fused_linear_cross_entropy(
# All Unsloth Zoo code licensed under LGPLv3
reduction = "sum" if num_items_in_batch is not None else "mean"
if logit_softcapping == 0: logit_softcapping = None
loss = linear_cross_entropy(
hidden_states.to(lm_weight.dtype),
lm_weight,
targets = labels,
ignore_index = ignore_index,
softcap = logit_softcapping,
reduction = reduction,
shift = True,
filter_eps = accuracy_threshold,
)
with torch_cuda_device(lm_weight.device):
loss = linear_cross_entropy(
hidden_states.to(lm_weight.dtype),
lm_weight,
targets = labels,
ignore_index = ignore_index,
softcap = logit_softcapping,
reduction = reduction,
shift = True,
filter_eps = accuracy_threshold,
)
if num_items_in_batch is not None: loss = loss / num_items_in_batch
return loss
pass
Expand Down
150 changes: 150 additions & 0 deletions unsloth_zoo/vllm_rlhf_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
# Unsloth Zoo - Utilities for Unsloth
# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

import torch
# Public API of unsloth_zoo/vllm_rlhf_utils.py: only the two worker-extension
# classes are exported by a star import. `stateless_init_process_group` is
# defined in this module but intentionally not listed here.
__all__ = [
"WorkerExtension",
"ColocateWorkerExtension",
]

def stateless_init_process_group(master_address, master_port, rank, world_size,
                                 device):
    """
    Create a vLLM `PyNcclCommunicator` backed by a `StatelessProcessGroup`.

    vLLM provides `StatelessProcessGroup` to create a process group without
    considering the global process group in torch.distributed. The group is
    created first and then wrapped in a NCCL communicator, which is the
    data-plane channel between external (training) processes and vLLM workers.

    Args:
        master_address: Host passed to `StatelessProcessGroup.create`.
        master_port: Port passed to `StatelessProcessGroup.create`.
        rank: Rank of the calling process inside the new group.
        world_size: Total number of processes in the new group.
        device: Device handed to `PyNcclCommunicator`.

    Returns:
        The `PyNcclCommunicator` built on top of the new process group.
    """
    # vLLM is imported lazily so that importing this module does not require it.
    from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
    from vllm.distributed.utils import StatelessProcessGroup

    process_group = StatelessProcessGroup.create(
        host=master_address,
        port=master_port,
        rank=rank,
        world_size=world_size,
    )
    return PyNcclCommunicator(process_group, device=device)


class WorkerExtension:
    """
    The class for vLLM's worker to inherit from.

    By defining an extension class, the code can work no matter what is
    the underlying worker class. This way, the code can be compatible
    with both vLLM V0 and V1.

    The methods below read `self.device` and `self.model_runner`, which this
    class does not define; they are expected to come from the worker class it
    is mixed into.

    NOTE: we define this class in a separate module, and the main module
    should pass the full qualified name as `worker_extension_cls` argument.
    """

    def init_weight_update_group(self, master_address, master_port,
                                 rank_offset, world_size):
        """
        Create the communicator used to receive weight broadcasts.

        The rank inside the new group is this worker's rank in vLLM's world
        group plus `rank_offset`. The resulting communicator is stored on
        `self.model_update_group`.
        """
        from vllm.distributed.parallel_state import get_world_group
        rank = get_world_group().rank + rank_offset
        self.model_update_group = stateless_init_process_group(
            master_address,
            master_port,
            rank,
            world_size,
            self.device,
        )

    def update_weight(self, name, dtype, shape):
        """
        Receive one tensor broadcast from rank 0 and load it into the model.

        Args:
            name: Parameter name passed to the model's `load_weights`.
            dtype: dtype of the receive buffer.
            shape: Shape of the receive buffer.
        """
        # Allocate the receive buffer on the same device the communicator was
        # created for (see `init_weight_update_group`) instead of whatever the
        # ambient "cuda" device happens to be.
        weight = torch.empty(shape, dtype=dtype, device=self.device)
        self.model_update_group.broadcast(
            weight,
            src=0,
            stream=torch.cuda.current_stream(self.device),
        )
        self.model_runner.model.load_weights(weights=[(name, weight)])

    def check_weights_changed(self):
        """
        Return True when every model parameter is (close to) all zeros.

        A model without parameters yields True. Evaluation stops at the first
        parameter that is not all zeros.
        """
        return all(
            torch.allclose(param, torch.zeros_like(param))
            for _, param in self.model_runner.model.named_parameters()
        )


class ColocateWorkerExtension:
    """
    The class for vLLM's worker to inherit from, in the colocate setting.

    By defining an extension class, the code can work no matter what is
    the underlying worker class. This way, the code can be compatible
    with both vLLM V0 and V1.

    The methods below read `self.device` and `self.model_runner`, which this
    class does not define; they are expected to come from the worker class it
    is mixed into.

    NOTE: we define this class in a separate module, and the main module
    should pass the full qualified name as `worker_extension_cls` argument.
    """

    def report_device_id(self) -> str:
        """Look up, cache on `self.device_uuid`, and return this device's UUID."""
        from vllm.platforms import current_platform
        self.device_uuid = current_platform.get_device_uuid(self.device.index)
        return self.device_uuid

    def update_weights_from_ipc_handles(self, ipc_handles):
        """
        Rebuild tensors from IPC handles and load them into the model.

        Args:
            ipc_handles: Mapping of device UUID -> {weight name: (rebuild
                function, argument tuple)}. Only the entry for this worker's
                device UUID is used.
        """
        # `device_uuid` is only set by `report_device_id`; resolve it lazily so
        # calling this method first does not fail with AttributeError.
        device_uuid = getattr(self, "device_uuid", None)
        if device_uuid is None:
            device_uuid = self.report_device_id()

        handles = ipc_handles[device_uuid]
        device_id = self.device.index
        weights = []
        for name, (func, args) in handles.items():
            list_args = list(args)
            # the key is to change device id to the current device id
            # in case two processes have different CUDA_VISIBLE_DEVICES
            # NOTE(review): index 6 is presumably the device slot of the
            # argument tuple produced by `reduce_tensor` for CUDA tensors -
            # confirm against the installed torch version.
            list_args[6] = device_id
            weights.append((name, func(*list_args)))
        self.model_runner.model.load_weights(weights=weights)
        torch.cuda.synchronize()

    def get_model_runner(self):
        """
        Collect the stacked LoRA A tensors of every layer's `qkv_proj`.

        Despite its name, this does not return the model runner: it returns a
        flat list holding entries 0, 1 and 2 of
        `self_attn.qkv_proj.lora_a_stacked` for each layer, in layer order.
        """
        vllm_model = self.model_runner.model
        vllm_loras_A = []
        for v_layer in vllm_model.model.layers:
            lora_a_stacked = v_layer.self_attn.qkv_proj.lora_a_stacked
            for index in range(3):
                vllm_loras_A.append(lora_a_stacked[index])
        torch.cuda.synchronize()
        return vllm_loras_A

    def check_weights_changed(self):
        """
        Return True when every model parameter is (close to) all zeros.

        A model without parameters yields True. Evaluation stops at the first
        parameter that is not all zeros.
        """
        return all(
            torch.allclose(param, torch.zeros_like(param))
            for _, param in self.model_runner.model.named_parameters()
        )

    def get_weight_ipc_handles(self):
        """
        Export every model parameter as an IPC handle.

        Returns:
            `{device_uuid: {parameter name: reduce_tensor(parameter)}}` for
            this worker's device.
        """
        from torch.multiprocessing.reductions import reduce_tensor

        # Resolve the UUID lazily so this works even if `report_device_id`
        # has not been called on this worker yet.
        device_uuid = getattr(self, "device_uuid", None)
        if device_uuid is None:
            device_uuid = self.report_device_id()

        vllm_model = self.model_runner.model
        # the training actor might only have a subset of the weights
        # and need to all-gather the weights from all the actors.
        # for demonstration, here we assume all training actors have
        # the full weights.
        data = {
            name: reduce_tensor(param.detach())
            for name, param in vllm_model.named_parameters()
        }
        return {device_uuid: data}
52 changes: 40 additions & 12 deletions unsloth_zoo/vllm_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,8 +394,11 @@ def unpatch_bitsandbytes_compute_dtype():


def patch_vllm():
# patch_bitsandbytes_quant_state()
# patch_vllm_bitsandbytes()
# Temporary patch to disable multiprocessing for vLLM
# Allows accessing model_executor
os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
patch_bitsandbytes_quant_state()
patch_vllm_bitsandbytes()
patch_vllm_lora_tokenizer()
patch_vllm_lora_load_tensors()
global LORA_REQUEST_ID
Expand Down Expand Up @@ -442,12 +445,28 @@ def vllm_dynamic_quant_supported(
def get_vllm_state_dict(llm, return_state_dict = False, config = None):
# All Unsloth Zoo code licensed under LGPLv3
# Unmerges vLLM modules and returns HF equivalent state_dict
# vllm_state_dict = {}
try:
llm_engine = getattr(llm, "llm_engine", getattr(llm, "engine", llm))
vllm_internals = llm_engine.model_executor.driver_worker.model_runner.model

# for name, p in vllm_internals.named_parameters():
# vllm_state_dict[name] = p
except:
raise RuntimeError("Unsloth: Failed to access llm.llm_engine.model_executor.driver_worker.model_runner.model")
# Using a new VLLM version must use collective_rpc
try:
vllm_state_dict = {}
gpu_ids = llm.collective_rpc("report_device_id", args = tuple())
weights = llm.collective_rpc("get_weight_ipc_handles", args = tuple())[0]
weights = weights[gpu_ids[0]]
for weight_name, (to_cuda_fx, cuda_data,) in weights.items():
vllm_state_dict[weight_name] = to_cuda_fx(*cuda_data)
pass
raise NotImplementedError("Unsloth: Currently vLLM RPC is not yet fully enabled!")
except Exception as e:
raise RuntimeError(f"Unsloth: Cannot get internal vLLM states with error = {str(e)}")
pass

assert(config is not None)
vocab_size = config.vocab_size

Expand Down Expand Up @@ -516,15 +535,22 @@ def get_state_dict(prefix, kk, state_dict, proj):
proj = vllm_internals.model.layers[kk].mlp.down_proj
get_state_dict(f"model.layers.{kk}.mlp.down_proj", 0, state_dict, proj)

state_dict[f"model.layers.{kk}.input_layernorm.weight"] = \
vllm_internals.model.layers[kk].input_layernorm.state_dict()["weight"]
quant_state_dict[f"model.layers.{kk}.input_layernorm.weight"] = \
state_dict[f"model.layers.{kk}.input_layernorm.weight"]

state_dict[f"model.layers.{kk}.post_attention_layernorm.weight"] = \
vllm_internals.model.layers[kk].post_attention_layernorm.state_dict()["weight"]
quant_state_dict[f"model.layers.{kk}.post_attention_layernorm.weight"] = \
state_dict[f"model.layers.{kk}.post_attention_layernorm.weight"]
for layernorm_name in [
f"model.layers.{kk}.input_layernorm",
f"model.layers.{kk}.post_attention_layernorm",
f"model.layers.{kk}.pre_feedforward_layernorm", # Gemma3
f"model.layers.{kk}.post_feedforward_layernorm", # Gemma3
f"model.layers.{kk}.self_attn.q_norm", # Qwen3, Gemma3
f"model.layers.{kk}.self_attn.k_norm", # Qwen3, Gemma3
]:
vllm_name = layernorm_name.replace(f".{kk}.", f"[{kk}].")
vllm_name = f"vllm_internals.{vllm_name}"
try:
layernorm = eval(vllm_name).state_dict()["weight"]
state_dict[layernorm_name + ".weight"] = layernorm
except:
print(f"vllm_internals.{layernorm_name}")
pass
pass

# Norm
Expand Down Expand Up @@ -1064,6 +1090,8 @@ def load_vllm(
enforce_eager = enforce_eager,
swap_space = swap_space, # Low memory devices like Colab (13GB) default 4GB
device = device,
# New vLLM versions need to pass this in!
# worker_extension_cls = "unsloth_zoo.vllm_rlhf_utils.ColocateWorkerExtension",
)
good_keys = inspect.signature(AsyncEngineArgs if use_async else EngineArgs).parameters.keys()
old_keys = engine_args.keys()
Expand Down