-
Notifications
You must be signed in to change notification settings - Fork 2.3k
/
finetuning.py
428 lines (389 loc) · 15.6 KB
/
finetuning.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
import dataclasses
import os
import random
from collections import Counter
from warnings import warn
import fire
import numpy as np
import torch
import torch.optim as optim
from accelerate.utils import is_xpu_available
from llama_recipes.configs import (
fsdp_config as FSDP_CONFIG,
quantization_config as QUANTIZATION_CONFIG,
train_config as TRAIN_CONFIG,
)
from llama_recipes.data.concatenator import ConcatDataset
from llama_recipes.policies import AnyPrecisionAdamW, apply_fsdp_checkpointing
from llama_recipes.utils import fsdp_auto_wrap_policy
from llama_recipes.utils.config_utils import (
check_fsdp_config,
generate_dataset_config,
generate_peft_config,
get_dataloader_kwargs,
update_config,
)
from llama_recipes.utils.dataset_utils import (
get_custom_data_collator,
get_preprocessed_dataset,
)
from llama_recipes.utils.fsdp_utils import hsdp_device_mesh
from llama_recipes.utils.train_utils import (
clear_gpu_cache,
freeze_transformer_layers,
freeze_LLM_only,
get_policies,
print_model_size,
print_frozen_model_status,
setup,
setup_environ_flags,
train,
)
from peft import get_peft_model, PeftModel
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
from torch.optim.lr_scheduler import StepLR
from transformers import (
AutoConfig,
AutoProcessor,
AutoTokenizer,
BitsAndBytesConfig,
LlamaForCausalLM,
MllamaForConditionalGeneration,
)
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
from transformers.models.mllama.modeling_mllama import (
MllamaCrossAttentionDecoderLayer,
MllamaSelfAttentionDecoderLayer,
MllamaVisionEncoderLayer,
)
def setup_wandb(train_config, fsdp_config, **kwargs):
    """Start a Weights & Biases run and record the training configuration.

    The wandb config is built from ``WANDB_CONFIG`` defaults, overridden by any
    matching entries in ``kwargs`` (via ``update_config``), and passed to
    ``wandb.init``. ``train_config`` and ``fsdp_config`` are then stored in the
    run's config.

    Args:
        train_config: training configuration object, added to ``run.config``.
        fsdp_config: FSDP configuration object, added to ``run.config``.
        **kwargs: command-line overrides; entries matching wandb config fields
            are applied to the wandb config.

    Returns:
        The run object returned by ``wandb.init``.

    Raises:
        ImportError: if the ``wandb`` package is not installed. The original
            import failure is chained as ``__cause__``.
    """
    try:
        import wandb
    except ImportError as err:
        # Chain the original failure so its details are not lost.
        raise ImportError(
            "You are trying to use wandb which is not currently installed. "
            "Please install it using pip install wandb"
        ) from err
    from llama_recipes.configs import wandb_config as WANDB_CONFIG

    wandb_config = WANDB_CONFIG()
    update_config(wandb_config, **kwargs)
    init_dict = dataclasses.asdict(wandb_config)
    run = wandb.init(**init_dict)
    run.config.update(train_config)
    # allow_val_change=True lets fsdp_config values replace keys already set in
    # the run config (presumably keys shared with train_config — verify).
    run.config.update(fsdp_config, allow_val_change=True)
    return run
def main(**kwargs):
    """Fine-tune a Llama (text) or Mllama (vision) model.

    Keyword arguments are forwarded to ``update_config`` to override fields of
    the train/FSDP configs. Where relevant they are also passed on to the
    quantization, wandb, PEFT and dataset config builders. The function then:

    1. seeds the RNGs and, when ``enable_fsdp`` is set, initialises the
       distributed setup from the torchrun environment variables;
    2. optionally starts a wandb run (only on rank 0 when FSDP is enabled);
    3. builds an optional bitsandbytes quantization config;
    4. loads the model and tokenizer (plus a processor for vision models);
    5. optionally wraps the model with PEFT and/or FSDP;
    6. builds the train (and optional validation) DataLoaders;
    7. creates the optimizer and StepLR scheduler and calls ``train``;
    8. prints the returned metrics, and logs them to wandb, on rank 0
       (or always, when FSDP is disabled).

    Raises:
        KeyError: with FSDP enabled, if LOCAL_RANK / RANK / WORLD_SIZE are
            missing from the environment.
        ValueError: for 8-bit quantization combined with FSDP, an unsupported
            ``model_type``, packing requested for a vision dataset, or a
            validation set too small to yield a single batch.
    """
    # Update the configuration for the training and sharding process
    train_config, fsdp_config = TRAIN_CONFIG(), FSDP_CONFIG()
    update_config((train_config, fsdp_config), **kwargs)

    # Set the seeds for reproducibility
    if is_xpu_available():
        torch.xpu.manual_seed(train_config.seed)
    torch.manual_seed(train_config.seed)
    random.seed(train_config.seed)
    np.random.seed(train_config.seed)

    if train_config.enable_fsdp:
        setup()
        # torchrun specific
        local_rank = int(os.environ["LOCAL_RANK"])
        rank = int(os.environ["RANK"])
        world_size = int(os.environ["WORLD_SIZE"])

    # local_rank / rank only exist when FSDP is enabled; guarding on enable_fsdp
    # avoids a NameError if some other code initialised the process group.
    if train_config.enable_fsdp and torch.distributed.is_initialized():
        if is_xpu_available():
            torch.xpu.set_device(local_rank)
        elif torch.cuda.is_available():
            torch.cuda.set_device(local_rank)
        clear_gpu_cache(local_rank)
        setup_environ_flags(rank)

    wandb_run = None
    if train_config.use_wandb:
        # `rank` is only evaluated when FSDP is enabled (short-circuit).
        if not train_config.enable_fsdp or rank == 0:
            wandb_run = setup_wandb(train_config, fsdp_config, **kwargs)

    # setting quantization configs
    bnb_config = None
    if train_config.quantization:
        if isinstance(train_config.quantization, bool):
            warn(
                "Quantization (--quantization) is a boolean, please specify quantization as '4bit' or '8bit'. Defaulting to '8bit' but this might change in the future.",
                FutureWarning,
            )
            train_config.quantization = "8bit"

        if train_config.quantization == "8bit" and train_config.enable_fsdp:
            raise ValueError(
                "8bit quantization is not supported with FSDP, please use 4bit quantization"
            )

        quant_config = QUANTIZATION_CONFIG()
        update_config(quant_config, **kwargs)
        bnb_config = quant_config.create_bnb_config(train_config.quantization)

    # Load the pre-trained model and setup its configuration
    use_cache = False if train_config.enable_fsdp else None
    config = AutoConfig.from_pretrained(train_config.model_name)
    if config.model_type == "mllama":
        is_vision = True
        model = MllamaForConditionalGeneration.from_pretrained(
            train_config.model_name,
            quantization_config=bnb_config,
            attn_implementation="sdpa" if train_config.use_fast_kernels else None,
            # Let HF place a quantized model automatically, but only without FSDP.
            device_map=(
                "auto"
                if train_config.quantization and not train_config.enable_fsdp
                else None
            ),
            torch_dtype=torch.float16 if train_config.use_fp16 else torch.bfloat16,
        )
        processor = AutoProcessor.from_pretrained(
            train_config.model_name
            if train_config.tokenizer_name is None
            else train_config.tokenizer_name
        )
        processor.tokenizer.padding_side = "right"
        model.supports_gradient_checkpointing = True
        model.language_model.supports_gradient_checkpointing = True
    elif config.model_type == "llama":
        is_vision = False
        model = LlamaForCausalLM.from_pretrained(
            train_config.model_name,
            quantization_config=bnb_config,
            use_cache=use_cache,
            attn_implementation="sdpa" if train_config.use_fast_kernels else None,
            device_map=(
                "auto"
                if train_config.quantization and not train_config.enable_fsdp
                else None
            ),
            torch_dtype=torch.float16 if train_config.use_fp16 else torch.bfloat16,
        )
    else:
        raise ValueError(
            f"Model type {config.model_type} is not supported. Please use llama or mllama model."
        )

    # Load the tokenizer and add special tokens
    tokenizer = AutoTokenizer.from_pretrained(
        train_config.model_name
        if train_config.tokenizer_name is None
        else train_config.tokenizer_name
    )
    # Only fill in a missing pad token; a pad id of 0 is a valid id and must be kept.
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id

    # If there is a mismatch between tokenizer vocab size and embedding matrix,
    # throw a warning and then expand the embedding matrix
    if len(tokenizer) > model.get_input_embeddings().weight.shape[0]:
        print(
            "WARNING: Resizing the embedding matrix to match the tokenizer vocab size."
        )
        model.resize_token_embeddings(len(tokenizer))

    print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)

    # Convert the model to bfloat16 if fsdp and pure_bf16 is enabled
    if (
        train_config.enable_fsdp
        and fsdp_config.pure_bf16
        and not train_config.quantization
    ):
        model.to(torch.bfloat16)

    if train_config.use_peft:
        if train_config.from_peft_checkpoint:
            # Load the pre-trained peft model checkpoint and setup its configuration
            model = PeftModel.from_pretrained(
                model, train_config.from_peft_checkpoint, is_trainable=True
            )
            peft_config = model.peft_config
        else:
            # Generate the peft config and start fine-tuning from original model
            peft_config = generate_peft_config(train_config, kwargs)
            model = get_peft_model(model, peft_config)
        if wandb_run:
            wandb_run.config.update(peft_config)
        model.print_trainable_parameters()

    hsdp_device_mesh_plan = None
    if (
        fsdp_config.hsdp
        and fsdp_config.sharding_strategy == ShardingStrategy.HYBRID_SHARD
    ):
        hsdp_device_mesh_plan = hsdp_device_mesh(
            replica_group_size=fsdp_config.replica_group_size,
            sharding_group_size=fsdp_config.sharding_group_size,
        )
        print("HSDP device mesh is ready")

    # setting up FSDP if enable_fsdp is enabled
    if train_config.enable_fsdp:
        check_fsdp_config(fsdp_config)

        if not train_config.use_peft and train_config.freeze_layers:
            freeze_transformer_layers(model, train_config.num_freeze_layers)
            # print model size and frozen layers after freezing layers
            print_frozen_model_status(model, train_config, rank)

        if (
            not train_config.use_peft
            and train_config.freeze_LLM_only
            and config.model_type == "mllama"
        ):
            freeze_LLM_only(model)
            # print model size and frozen layers after freezing layers
            print_frozen_model_status(model, train_config, rank)

        mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
        if is_vision:
            # Create the FSDP wrapper for MllamaSelfAttentionDecoderLayer,
            # MllamaCrossAttentionDecoderLayer and MllamaVisionEncoderLayer in vision models
            my_auto_wrapping_policy = fsdp_auto_wrap_policy(
                model,
                [
                    MllamaSelfAttentionDecoderLayer,
                    MllamaCrossAttentionDecoderLayer,
                    MllamaVisionEncoderLayer,
                ],
            )
        else:
            # Create the FSDP wrapper for LlamaDecoderLayer in text models
            my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, [LlamaDecoderLayer])

        device_id = 0
        if is_xpu_available():
            device_id = torch.xpu.current_device()
        elif torch.cuda.is_available():
            device_id = torch.cuda.current_device()

        # Partially frozen LLM weights need FSDP to keep the original parameters.
        use_orig_params = bool(train_config.freeze_LLM_only)

        model = FSDP(
            model,
            auto_wrap_policy=(
                my_auto_wrapping_policy if train_config.use_peft else wrapping_policy
            ),
            cpu_offload=(
                CPUOffload(offload_params=True)
                if fsdp_config.fsdp_cpu_offload
                else None
            ),
            mixed_precision=(
                mixed_precision_policy if not fsdp_config.pure_bf16 else None
            ),
            sharding_strategy=fsdp_config.sharding_strategy,
            device_mesh=hsdp_device_mesh_plan,
            device_id=device_id,
            limit_all_gathers=True,
            # With low_cpu_fsdp, non-zero ranks materialise empty modules and
            # receive rank 0's weights via sync_module_states.
            sync_module_states=train_config.low_cpu_fsdp,
            # NOTE(review): device is hard-coded to "cuda" even when running on
            # XPU; confirm whether this should follow device_id.
            param_init_fn=(
                (
                    lambda module: module.to_empty(
                        device=torch.device("cuda"), recurse=False
                    )
                )
                if train_config.low_cpu_fsdp and rank != 0
                else None
            ),
            use_orig_params=use_orig_params,
        )
        if fsdp_config.fsdp_activation_checkpointing:
            model.enable_input_require_grads()
            model.gradient_checkpointing_enable()
            apply_fsdp_checkpointing(model)
    elif not train_config.quantization and not train_config.enable_fsdp:
        # Quantized models were already placed via device_map="auto".
        if is_xpu_available():
            model.to("xpu:0")
        elif torch.cuda.is_available():
            model.to("cuda")

    dataset_config = generate_dataset_config(train_config, kwargs)
    dataset_processer = processor if is_vision else tokenizer

    # Load and preprocess the dataset for training and validation
    dataset_train = get_preprocessed_dataset(
        dataset_processer,
        dataset_config,
        split="train",
    )
    if not train_config.enable_fsdp or rank == 0:
        print(f"--> Training Set Length = {len(dataset_train)}")

    dataset_val = get_preprocessed_dataset(
        dataset_processer,
        dataset_config,
        split="test",
    )
    if not train_config.enable_fsdp or rank == 0:
        print(f"--> Validation Set Length = {len(dataset_val)}")

    if train_config.batching_strategy == "packing":
        if is_vision:
            raise ValueError("Packing is not supported for vision datasets")
        dataset_train = ConcatDataset(
            dataset_train, chunk_size=train_config.context_length
        )

    train_dl_kwargs = get_dataloader_kwargs(
        train_config, dataset_train, dataset_processer, "train"
    )
    print("length of dataset_train", len(dataset_train))
    custom_data_collator = get_custom_data_collator(dataset_processer, dataset_config)
    if custom_data_collator:
        print("custom_data_collator is used")
        train_dl_kwargs["collate_fn"] = custom_data_collator

    # Create DataLoaders for the training and validation dataset
    train_dataloader = torch.utils.data.DataLoader(
        dataset_train,
        num_workers=train_config.num_workers_dataloader,
        pin_memory=True,
        **train_dl_kwargs,
    )
    print(f"--> Num of Training Set Batches loaded = {len(train_dataloader)}")

    eval_dataloader = None
    if train_config.run_validation:
        if train_config.batching_strategy == "packing":
            if is_vision:
                raise ValueError("Packing is not supported for vision datasets")
            dataset_val = ConcatDataset(
                dataset_val, chunk_size=train_config.context_length
            )

        val_dl_kwargs = get_dataloader_kwargs(
            train_config, dataset_val, dataset_processer, "val"
        )
        if custom_data_collator:
            val_dl_kwargs["collate_fn"] = custom_data_collator

        eval_dataloader = torch.utils.data.DataLoader(
            dataset_val,
            num_workers=train_config.num_workers_dataloader,
            pin_memory=True,
            **val_dl_kwargs,
        )
        # Report the batch count once (it was previously printed twice).
        print(f"--> Num of Validation Set Batches loaded = {len(eval_dataloader)}")
        if len(eval_dataloader) == 0:
            raise ValueError(
                f"The eval set size is too small for dataloader to load even one batch. Please increase the size of eval set. ({len(eval_dataloader)=})"
            )

    # Initialize the optimizer and learning rate scheduler
    if fsdp_config.pure_bf16 and fsdp_config.optimizer == "anyprecision":
        optimizer = AnyPrecisionAdamW(
            model.parameters(),
            lr=train_config.lr,
            momentum_dtype=torch.bfloat16,
            variance_dtype=torch.bfloat16,
            use_kahan_summation=False,
            weight_decay=train_config.weight_decay,
        )
    else:
        optimizer = optim.AdamW(
            model.parameters(),
            lr=train_config.lr,
            weight_decay=train_config.weight_decay,
        )
    scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma)

    results = train(
        model,
        train_dataloader,
        eval_dataloader,
        tokenizer,
        optimizer,
        scheduler,
        train_config.gradient_accumulation_steps,
        train_config,
        fsdp_config if train_config.enable_fsdp else None,
        local_rank if train_config.enable_fsdp else None,
        rank if train_config.enable_fsdp else None,
        wandb_run,
    )

    if not train_config.enable_fsdp or rank == 0:
        for k, v in results.items():
            print(f"Key: {k}, Value: {v}")
        # wandb_run is only created when use_wandb is set on this rank.
        if wandb_run is not None:
            for k, v in results.items():
                wandb_run.summary[k] = v
if __name__ == "__main__":
    # Script entry point: hand `main` to python-fire, which forwards
    # command-line arguments to it as keyword arguments.
    fire.Fire(main)