lora_converter.py
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import re
import tarfile
import tempfile
from pathlib import Path
from typing import Any, Dict, List, Tuple

import torch
import yaml

from nemo.export.tarutils import TarPath


def replace_number_add_offset(key, offset_value):
    """Find the layer number in the state dict key and add a numeric offset to that number."""
    if offset_value == 0:
        return key

    pattern = r'layers.(\d+)'

    def add_offset(match):
        return "layers." + str(int(match.group(1)) + offset_value)

    return re.sub(pattern, add_offset, key)


def rename_qkv_keys(key):
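    """Expand a fused .lora_kqv_adapter. key into separate q/k/v adapter keys (in that order)."""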
    new_keys = []
    new_keys.append(key.replace(".lora_kqv_adapter.", ".lora_unfused_kqv_adapter.q_adapter."))
    new_keys.append(key.replace(".lora_kqv_adapter.", ".lora_unfused_kqv_adapter.k_adapter."))
    new_keys.append(key.replace(".lora_kqv_adapter.", ".lora_unfused_kqv_adapter.v_adapter."))
    return new_keys


def reformat_module_names_to_hf(tensors: Dict[str, torch.Tensor]) -> Tuple[Dict[str, torch.Tensor], List[str]]:
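    """Rename NeMo LoRA module names to their HF/PEFT (vLLM-compatible) counterparts.

    For illustration, a hypothetical unfused NeMo key such as
    decoder.layers.0.self_attention.adapter_layer.lora_unfused_kqv_adapter.q_adapter.linear_in.weight
    is renamed to base_model.model.layers.0.self_attn.q_proj.lora_A.weight.

    Returns:
        Tuple[Dict[str, torch.Tensor], List[str]]: The renamed tensors and the list of HF
        target module names that were found.
    """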
    new_tensors = dict()
    module_names = set()
    known_module_names = ["q_proj", "k_proj", "v_proj", "o_proj", "down_proj", "gate_proj", "up_proj"]

    for module_name, module_weight in tensors.items():
        # map linear_in and linear_out to lora_a/lora_b counterparts
        new_module_name = "base_model." + module_name.replace("linear_in", "lora_A").replace("linear_out", "lora_B")

        # map target modules to their vLLM/HF counterparts
        # (note: the "v_adapter" replacement also turns "lora_unfused_kqv_adapter" into
        # "lora_unfused_kqv_proj", which is stripped further below)
        new_module_name = new_module_name.replace("q_adapter", "q_proj")
        new_module_name = new_module_name.replace("k_adapter", "k_proj")
        new_module_name = new_module_name.replace("v_adapter", "v_proj")
        new_module_name = new_module_name.replace("lora_dense_attention_adapter", "o_proj")
        new_module_name = new_module_name.replace("lora_4htoh_adapter", "down_proj")
        new_module_name = new_module_name.replace("gate_adapter", "gate_proj")
        new_module_name = new_module_name.replace("up_adapter", "up_proj")

        # map other parts of the module names to fit vLLM/huggingface
        new_module_name = new_module_name.replace(".adapter_layer", "")
        new_module_name = new_module_name.replace(".lora_unfused_kqv_proj", "")
        new_module_name = new_module_name.replace(".lora_unfused_hto4h_adapter", "")
        new_module_name = new_module_name.replace("self_attention", "self_attn")
        new_module_name = new_module_name.replace("decoder", "model")

        new_tensors[new_module_name] = module_weight

        # keep track of the modules that we've added to store them in the config file
        for kmn in known_module_names:
            if f'.{kmn}' in new_module_name:
                module_names.add(kmn)

    return (new_tensors, list(module_names))


def convert_lora_weights_to_canonical(
    config: Dict[str, Any], lora_weights: Dict[str, torch.Tensor]
) -> Dict[str, torch.Tensor]:
"""This function converts nemo style (fused) lora weights to canonical (unfused)
LoRA weights. Namely, it unfuses the QKV adapter layers and the H-to-4H adapter layers.
Returns:
Dict[str, torch.Tensor]: The new LoRA weights with unfused layers.
"""
    hidden_size = int(config["hidden_size"])
    num_heads = int(config["num_attention_heads"])
    head_size = hidden_size // num_heads
    num_query_groups = int(config.get("num_query_groups", num_heads))  # num_kv_heads
    heads_per_group = num_heads // num_query_groups
    qkv_total_dim = num_heads + 2 * num_query_groups
    adapter_size = config['peft']['lora_tuning']['adapter_dim']
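
    # The fused QKV projection interleaves heads per query group as
    # [q_0 ... q_{heads_per_group-1}, k, v], giving qkv_total_dim = num_heads + 2 * num_query_groups
    # rows of size head_size. The index tensors below select the rows belonging to Q, K and V.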
    q_slice = torch.cat(
        [
            torch.arange((heads_per_group + 2) * group_idx, (heads_per_group + 2) * group_idx + heads_per_group)
            for group_idx in range(num_query_groups)
        ]
    )
    k_slice = torch.arange(heads_per_group, qkv_total_dim, heads_per_group + 2)
    v_slice = torch.arange(heads_per_group + 1, qkv_total_dim, heads_per_group + 2)

    qkv_keys_to_update = []
    hto4h_keys_to_update = []
    for key in lora_weights.keys():
        if "lora_kqv_adapter" in key:
            qkv_keys_to_update.append(key)
        if "lora_hto4h_adapter" in key:
            hto4h_keys_to_update.append(key)

    # unfuse QKV layer
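    # LoRA A (linear_in) is shared by Q, K and V, so it is simply duplicated per projection;
    # LoRA B (linear_out) stacks rows for all heads, so it is reshaped and gathered with the
    # q/k/v index tensors computed above.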
    for key in qkv_keys_to_update:
        if "linear_in" in key:
            assert lora_weights[key].size(0) == adapter_size
            for new_key in rename_qkv_keys(key):
                lora_weights[new_key] = lora_weights[key]
                assert len(lora_weights[new_key].size()) == 2
        elif "linear_out" in key:
            assert lora_weights[key].size(1) == adapter_size
            for new_key, size in zip(rename_qkv_keys(key), [q_slice, k_slice, v_slice]):
                lora_weights[new_key] = (
                    lora_weights[key]
                    .reshape((qkv_total_dim, head_size, adapter_size))[size]
                    .reshape((-1, adapter_size))
                )
                assert len(lora_weights[new_key].size()) == 2
        lora_weights.pop(key)

    # This maps to gate_up_proj in HF, but we need to split it up into gate_proj and up_proj
    for key in hto4h_keys_to_update:
        gate_proj_key = key.replace(".lora_hto4h_adapter.", ".lora_unfused_hto4h_adapter.gate_adapter.")
        up_proj_key = key.replace(".lora_hto4h_adapter.", ".lora_unfused_hto4h_adapter.up_adapter.")

        module_weight = lora_weights[key]
        if "linear_in" in key:
            # lora_a gets duplicated
            lora_weights[gate_proj_key] = module_weight
            lora_weights[up_proj_key] = module_weight
        elif "linear_out" in key:
            # lora_b gets split
            split_size = module_weight.shape[0]
            gate_up_split = module_weight.split(split_size // 2)
            lora_weights[gate_proj_key] = gate_up_split[0]
            lora_weights[up_proj_key] = gate_up_split[1]
        lora_weights.pop(key)

    return lora_weights


def convert_lora_nemo_to_canonical(lora_nemo, save_path, hf_format=False, donor_hf_config=None):
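    """Convert a fused NeMo LoRA checkpoint (.nemo tar) into the canonical (unfused) LoRA format.

    When hf_format is True, the adapter is written to save_path as a HF/PEFT-style directory
    containing adapter_model.bin and adapter_config.json (optionally seeded from donor_hf_config).
    Otherwise, the canonical adapter is repacked into a .nemo tar archive at save_path.

    Returns:
        The converted LoRA state dict and the updated lora config.
    """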
    with TarPath(lora_nemo) as archive:
        with (archive / "model_config.yaml").open("r") as config_file:
            lora_config = yaml.load(config_file, Loader=yaml.SafeLoader)

        tp_size = lora_config.get('tensor_model_parallel_size', 1)
        pp_size = lora_config.get('pipeline_model_parallel_size', 1)
        lora_state_dict = [{} for _ in range(tp_size)]

        for pp in range(pp_size):
            for tp in range(tp_size):
                if tp_size == 1:
                    ckpt_file = archive / "model_weights.ckpt"
                elif pp_size == 1:
                    ckpt_file = archive / f"mp_rank_{tp:02d}/model_weights.ckpt"
                else:
                    ckpt_file = archive / f"tp_rank_{tp:02d}_pp_rank_{pp:03d}/model_weights.ckpt"

                with ckpt_file.open("rb") as f:
                    weights = torch.load(f, map_location=torch.device('cpu'))
                    if pp == 0:
                        lora_state_dict[tp] = weights
                    else:
                        # calculate layer offset for this pipeline stage
                        layer_offset = lora_config['num_layers'] // pp_size * pp
                        for key, value in weights.items():
                            new_key = replace_number_add_offset(key, layer_offset)
                            lora_state_dict[tp][new_key] = value

        # TODO: currently only tp=1 is supported
        lora_state_dict = lora_state_dict[0]

    if lora_config['peft']['lora_tuning'].get('variant', 'nemo') == "nemo":
        lora_config['peft']['lora_tuning']['variant'] = "canonical"
        lora_state_dict = convert_lora_weights_to_canonical(lora_config, lora_state_dict)

    if hf_format:
        lora_state_dict, target_modules = reformat_module_names_to_hf(lora_state_dict)
        Path(save_path).mkdir(parents=True, exist_ok=True)
        torch.save(lora_state_dict, f"{save_path}/adapter_model.bin")
        if donor_hf_config is not None:
            with open(donor_hf_config) as hf_config_file:
                adapter_config = json.load(hf_config_file)
        else:
            adapter_config = {}
        adapter_config['peft_type'] = "LORA"
        adapter_config['r'] = lora_config['peft']['lora_tuning']['adapter_dim']
        adapter_config['lora_alpha'] = lora_config['peft']['lora_tuning']['alpha']
        adapter_config['target_modules'] = target_modules
        with open(f"{save_path}/adapter_config.json", "w") as f:
            json.dump(adapter_config, f, indent=4)
    else:
        with tempfile.TemporaryDirectory() as tmpdir:
            with open(f"{tmpdir}/model_config.yaml", "w") as f:
                yaml.dump(lora_config, f)
            torch.save(lora_state_dict, f"{tmpdir}/model_weights.ckpt")
            dirname = os.path.dirname(save_path)
            if dirname:
                os.makedirs(dirname, exist_ok=True)
            with tarfile.open(save_path, "w:") as tar:
                tar.add(tmpdir, arcname=".")

    return lora_state_dict, lora_config
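

# Example usage (a sketch only: the paths and checkpoint name below are hypothetical, and the
# import path depends on where this module lives in the package):
#
#     from lora_converter import convert_lora_nemo_to_canonical
#
#     # Convert a NeMo LoRA checkpoint into a HF/PEFT-style adapter directory.
#     convert_lora_nemo_to_canonical(
#         lora_nemo="/path/to/lora_checkpoint.nemo",
#         save_path="/path/to/hf_adapter_dir",
#         hf_format=True,
#     )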