-
Notifications
You must be signed in to change notification settings - Fork 377
/
Copy pathfinetune.py
935 lines (883 loc) · 40.1 KB
/
finetune.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
# !/usr/bin/env python
# coding=utf-8
# Copyright 2024 AllenAI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# isort: off
import os
os.environ["NCCL_CUMEM_ENABLE"] = "0" # NOQA
try:
import deepspeed
# @vwxyzjn: when importing on CPU-only machines, we get the following error:
# RuntimeError: 0 active drivers ([]). There should only be one.
# so we need to catch the exception and do nothing
# https://github.com/deepspeedai/DeepSpeed/issues/7028
except Exception:
pass
# isort: on
import logging
import math
import os
import shutil
import time
from dataclasses import dataclass, field
from datetime import timedelta
from typing import List, Literal, Optional, Union
import datasets
import torch
import transformers
from accelerate import Accelerator, DataLoaderConfiguration
from accelerate.logging import get_logger
from accelerate.utils import InitProcessGroupKwargs, set_seed
from huggingface_hub import HfApi
from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training
from rich.pretty import pprint
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import (
AutoConfig,
AutoModelForCausalLM,
BitsAndBytesConfig,
DataCollatorForSeq2Seq,
get_scheduler,
)
from open_instruct.dataset_transformation import (
INPUT_IDS_KEY,
TOKENIZED_SFT_DATASET_KEYS,
TokenizerConfig,
get_cached_dataset_tulu,
visualize_token,
)
from open_instruct.model_utils import push_folder_to_hub, save_with_accelerate
from open_instruct.utils import (
ArgumentParserPlus,
clean_last_n_checkpoints,
get_last_checkpoint_path,
get_wandb_tags,
is_beaker_job,
launch_ai2_evals_on_weka,
maybe_get_beaker_config,
maybe_use_ai2_hf_entity,
maybe_use_ai2_wandb_entity,
)
logger = get_logger(__name__)
@dataclass
class FlatArguments:
    """
    Full arguments class for all fine-tuning jobs.

    Every option is a flat dataclass field with a default value. Fields are
    documented either through ``metadata={"help": ...}`` or through the string
    literal that directly follows them. Cross-field consistency is checked in
    ``__post_init__``, which raises ``ValueError`` on an invalid combination.
    """

    exp_name: str = os.path.basename(__file__)[: -len(".py")]
    """The name of this experiment"""
    run_name: Optional[str] = None
    """A unique name of this run"""
    model_name_or_path: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The model checkpoint for weights initialization. Don't set if you want to train a model from scratch."
            )
        },
    )
    config_name: Optional[str] = field(
        default=None,
        metadata={"help": "Pretrained config name or path if not the same as model_name"},
    )
    use_flash_attn: bool = field(
        default=True,
        metadata={"help": "Whether to use flash attention in the model training"},
    )
    model_revision: Optional[str] = field(
        default=None,
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    low_cpu_mem_usage: bool = field(
        default=False,
        metadata={
            "help": (
                "It is an option to create the model as an empty shell, "
                "then only materialize its parameters when the pretrained weights are loaded. "
                "set True will benefit LLM loading time and RAM consumption."
            )
        },
    )
    dataset_name: Optional[str] = field(
        default=None,
        metadata={"help": "The name of the dataset to use (via the datasets library)."},
    )
    dataset_mixer: Optional[dict] = field(
        default=None,
        metadata={"help": "A dictionary of datasets (local or HF) to sample from."},
    )
    dataset_mixer_list: List[str] = field(default_factory=lambda: ["allenai/tulu-3-sft-personas-algebra", "1.0"])
    """A list of datasets (local or HF) to sample from."""
    dataset_mixer_list_splits: List[str] = field(default_factory=lambda: ["train"])
    """The dataset splits to use for training"""
    dataset_transform_fn: list[str] = field(
        default_factory=lambda: ["sft_tulu_tokenize_and_truncate_v1", "sft_tulu_filter_v1"]
    )
    """The list of transform functions to apply to the dataset."""
    # Copy the module-level constant so that mutating one instance's column list
    # cannot leak into the shared default (or into other instances).
    dataset_target_columns: List[str] = field(default_factory=lambda: list(TOKENIZED_SFT_DATASET_KEYS))
    """The columns to use for the dataset."""
    dataset_cache_mode: Literal["hf", "local"] = "local"
    """The mode to use for caching the dataset."""
    dataset_local_cache_dir: str = "local_dataset_cache"
    """The directory to save the local dataset cache to."""
    dataset_config_hash: Optional[str] = None
    """The hash of the dataset configuration."""
    dataset_skip_cache: bool = False
    """Whether to skip the cache."""
    dataset_mix_dir: Optional[str] = field(
        default=None,
        metadata={"help": "The directory to save the mixed dataset to disk."},
    )
    dataset_config_name: Optional[str] = field(
        default=None,
        metadata={"help": "The configuration name of the dataset to use (via the datasets library)."},
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of training examples to this "
                "value if set."
            )
        },
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
    max_seq_length: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "The maximum total input sequence length after tokenization. "
                "Sequences longer than this will be truncated."
            )
        },
    )
    overwrite_cache: bool = field(
        default=False,
        metadata={"help": "Overwrite the cached training and evaluation sets"},
    )
    clip_grad_norm: float = field(
        default=-1,
        metadata={"help": "Clip gradient norm. Not compatible with deepspeed (use deepspeed config instead)."},
    )
    gradient_accumulation_steps: int = field(
        default=1,
        metadata={"help": "Number of updates steps to accumulate before performing a backward/update pass."},
    )
    learning_rate: float = field(
        default=2e-5,
        metadata={"help": "The initial learning rate for AdamW optimizer."},
    )
    logging_steps: Optional[int] = field(
        default=None,
        metadata={"help": "Log the training loss and learning rate every logging_steps steps."},
    )
    lora_rank: int = field(
        default=64,
        metadata={"help": "The rank of lora."},
    )
    lora_alpha: float = field(
        default=16,
        metadata={"help": "The alpha parameter of lora."},
    )
    lora_dropout: float = field(
        default=0.1,
        metadata={"help": "The dropout rate of lora modules."},
    )
    lr_scheduler_type: str = field(
        default="linear",
        metadata={
            "help": "The scheduler type to use for learning rate adjustment.",
            "choices": [
                "linear",
                "cosine",
                "cosine_with_restarts",
                "polynomial",
                "constant",
                "constant_with_warmup",
            ],
        },
    )
    num_train_epochs: int = field(
        default=2,
        metadata={"help": "Total number of training epochs to perform."},
    )
    output_dir: str = field(
        default="output/",
        metadata={"help": "The output directory where the model predictions and checkpoints will be written."},
    )
    per_device_train_batch_size: int = field(
        default=8,
        metadata={"help": "Batch size per GPU/TPU core/CPU for training."},
    )
    use_lora: bool = field(
        default=False,
        metadata={"help": "If True, will use LORA (low-rank parameter-efficient training) to train the model."},
    )
    use_qlora: bool = field(
        default=False,
        metadata={"help": "Use qLoRA training - initializes model in quantized form. Not compatible with deepspeed."},
    )
    use_8bit_optimizer: bool = field(
        default=False,
        metadata={"help": "Use 8bit optimizer from bitsandbytes. Not compatible with deepspeed."},
    )
    warmup_ratio: float = field(
        default=0.03,
        metadata={"help": "Linear warmup over warmup_ratio fraction of total steps."},
    )
    weight_decay: float = field(
        default=0.0,
        metadata={"help": "Weight decay for AdamW if we apply some."},
    )
    timeout: int = field(
        default=1800,
        metadata={
            # Adjacent literals are concatenated verbatim, so the separating
            # space has to live inside one of them.
            "help": "Timeout for the training process in seconds. "
            "Useful if tokenization process is long. Default is 1800 seconds (30 minutes)."
        },
    )
    reduce_loss: str = field(
        default="mean",
        metadata={
            "help": "How to reduce loss over tokens. Options are 'mean' or 'sum'. "
            "Using 'sum' can improve chat model performance."
        },
    )
    resume_from_checkpoint: Optional[str] = field(
        default=None,
        metadata={"help": "If the training should continue from a checkpoint folder."},
    )
    report_to: Union[str, List[str]] = field(
        default="all",
        metadata={
            "help": "The integration(s) to report results and logs to. "
            "Can be a single string or a list of strings. "
            "Options are 'tensorboard', 'wandb', 'comet_ml', 'clearml', or 'all'. "
            "Specify multiple by listing them: e.g., ['tensorboard', 'wandb']"
        },
    )
    save_to_hub: Optional[str] = field(
        default=None,
        metadata={"help": "Save the model to the Hub under this name. E.g allenai/your-model"},
    )
    gradient_checkpointing: bool = field(
        default=False,
        metadata={"help": "Turn on gradient checkpointing. Saves memory but slows training."},
    )
    use_liger_kernel: bool = field(
        default=False,
        metadata={"help": "Whether to use LigerKernel for training."},
    )
    max_train_steps: Optional[int] = field(
        default=None,
        metadata={"help": "If set, overrides the number of training steps. Otherwise, num_train_epochs is used."},
    )
    seed: int = field(
        default=42,
        metadata={"help": "Random seed for initialization and dataset shuffling."},
    )
    checkpointing_steps: Optional[str] = field(
        default=None,
        metadata={
            "help": "Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch."  # noqa
        },
    )
    keep_last_n_checkpoints: int = field(
        default=3,
        metadata={"help": "How many checkpoints to keep in the output directory. -1 for all."},
    )
    fused_optimizer: bool = field(
        default=True,
        metadata={
            "help": "Whether to use fused AdamW or not.",
        },
    )
    load_balancing_loss: bool = field(
        default=False,
        metadata={
            "help": "Whether to include a load balancing loss (for OLMoE) or not.",
        },
    )
    load_balancing_weight: float = field(
        default=0.5,
        metadata={"help": "Weight for load balancing loss if applicable."},
    )

    # Experiment tracking
    with_tracking: bool = False
    """If toggled, this experiment will be tracked with Weights and Biases"""
    wandb_project_name: str = "open_instruct_internal"
    """The wandb's project name"""
    wandb_entity: Optional[str] = None
    """The entity (team) of wandb's project"""
    push_to_hub: bool = True
    """Whether to upload the saved model to huggingface"""
    hf_entity: Optional[str] = None
    """The user or org name of the model repository from the Hugging Face Hub"""
    hf_repo_id: Optional[str] = None
    """The id of the saved model in the Hugging Face Hub (can be autoset if not given)"""
    hf_repo_revision: Optional[str] = None
    """The revision of the saved model in the Hugging Face Hub (can be autoset if not given)"""
    hf_repo_url: Optional[str] = None
    """The url of the saved model in the Hugging Face Hub (will be autoset)"""
    try_launch_beaker_eval_jobs: bool = True
    """Whether to launch beaker evaluation jobs after training"""
    hf_metadata_dataset: Optional[str] = "allenai/tulu-3-evals"
    """What dataset to upload the metadata to. If unset, don't upload metadata"""
    cache_dataset_only: bool = False
    """Immediately exit after caching the dataset"""

    # Ai2 specific settings
    try_auto_save_to_beaker: bool = True
    """Whether to try to save the model to Beaker dataset `/output` after training"""
    gs_bucket_path: Optional[str] = None
    """The path to the gs bucket to save the model to"""
    oe_eval_tasks: Optional[List[str]] = None
    """The beaker evaluation tasks to launch"""
    oe_eval_max_length: int = 4096
    """the max generation length for evaluation for oe-eval"""

    def __post_init__(self):
        """Validate cross-field constraints once the dataclass is constructed.

        Raises:
            ValueError: if ``reduce_loss`` is not ``"mean"`` or ``"sum"``; if no
                dataset source is given at all; if the dataset selection fields
                are combined in a rejected way (see the note below); or if
                Beaker eval jobs are requested without ``push_to_hub``.
        """
        if self.reduce_loss not in ["mean", "sum"]:
            raise ValueError("reduce_loss must be either 'mean' or 'sum'")
        if self.dataset_name is None and self.dataset_mixer is None and self.dataset_mixer_list is None:
            raise ValueError("Need either a dataset name, dataset mixer, or dataset mixer list.")
        # Logically identical to the previous three-clause test
        # `(A and (B or C)) or A or (B and C)`, which reduces to `A or (B and C)`.
        # NOTE(review): as written, any non-None `dataset_name` is rejected, and
        # because `dataset_mixer_list` has a non-None default, `dataset_mixer` is
        # rejected too unless `dataset_mixer_list` is explicitly set to None.
        # Confirm this is intended before relaxing the check.
        if self.dataset_name is not None or (
            self.dataset_mixer is not None and self.dataset_mixer_list is not None
        ):
            raise ValueError("Cannot provide two dataset selection mechanisms.")
        if self.try_launch_beaker_eval_jobs and not self.push_to_hub:
            raise ValueError("Cannot launch Beaker evaluation jobs without pushing to the Hub.")
def main(args: FlatArguments, tc: TokenizerConfig):
# ------------------------------------------------------------
# Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
# If we're using tracking, we also need to initialize it here and it will by default pick up all supported trackers
# in the environment
accelerator_log_kwargs = {}
if args.with_tracking:
accelerator_log_kwargs["log_with"] = args.report_to
accelerator_log_kwargs["project_dir"] = args.output_dir
# if you get timeouts (e.g. due to long tokenization) increase this.
timeout_kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=args.timeout))
dataloader_config = DataLoaderConfiguration(use_seedable_sampler=True)
accelerator = Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps,
dataloader_config=dataloader_config,
**accelerator_log_kwargs,
kwargs_handlers=[timeout_kwargs],
)
# ------------------------------------------------------------
# Setup tokenizer
tc.tokenizer_revision = args.model_revision if tc.tokenizer_revision is None else tc.tokenizer_revision
tc.tokenizer_name_or_path = (
args.model_name_or_path if tc.tokenizer_name_or_path is None else tc.tokenizer_name_or_path
)
if tc.tokenizer_revision != args.model_revision and tc.tokenizer_name_or_path != args.model_name_or_path:
# Warn user if tokenizer and model use different revisions; this is an unusual
# use case.
warning = f"""Requested tokenizer revision `{tc.tokenizer_revision=}` is different
from the model revision `{args.model_revision=}` or the tokenizer name `{tc.tokenizer_name_or_path=}`
is different from the model name `{args.model_name_or_path=}`."""
logger.warning(warning)
tokenizer = tc.tokenizer
# ------------------------------------------------------------
# Set up runtime variables
args.run_name = f"{args.exp_name}__{args.seed}__{int(time.time())}"
args.output_dir = os.path.join(args.output_dir, args.run_name)
args.dataset_local_cache_dir = os.path.abspath(args.dataset_local_cache_dir)
if is_beaker_job():
args.dataset_local_cache_dir = "/weka/oe-adapt-default/allennlp/deletable_open_instruct_dataset_cache"
if args.push_to_hub and accelerator.is_main_process:
if args.hf_repo_id is None: # auto-generate one
args.hf_repo_id = "open_instruct_dev"
if args.hf_entity is None: # first try to use AI2 entity
args.hf_entity = maybe_use_ai2_hf_entity()
if args.hf_entity is None: # then try to use the user's entity
args.hf_entity = HfApi().whoami()["name"]
args.hf_repo_id = f"{args.hf_entity}/{args.hf_repo_id}"
if args.hf_repo_revision is None:
args.hf_repo_revision = args.run_name
args.hf_repo_url = f"https://huggingface.co/{args.hf_repo_id}/tree/{args.hf_repo_revision}"
if is_beaker_job():
beaker_config = maybe_get_beaker_config()
# ------------------------------------------------------------
# Initialize the trackers we use, and also store our configuration.
# The trackers initializes automatically on the main process.
if args.with_tracking:
experiment_config = vars(args)
# TensorBoard cannot log Enums, need the raw value
experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"]
# (Optional) Ai2 internal tracking
if args.wandb_entity is None:
args.wandb_entity = maybe_use_ai2_wandb_entity()
if accelerator.is_main_process and is_beaker_job():
experiment_config.update(vars(beaker_config))
experiment_config.update(vars(tc))
accelerator.init_trackers(
args.wandb_project_name,
experiment_config,
init_kwargs={
"wandb": {
"name": args.run_name,
"entity": args.wandb_entity,
"tags": [args.exp_name] + get_wandb_tags(),
}
},
)
wandb_tracker = accelerator.get_tracker("wandb")
if accelerator.is_main_process:
pprint([args, tc])
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
logger.info(accelerator.state, main_process_only=False)
if accelerator.is_local_main_process:
datasets.utils.logging.set_verbosity_warning()
transformers.utils.logging.set_verbosity_info()
else:
datasets.utils.logging.set_verbosity_error()
transformers.utils.logging.set_verbosity_error()
# If passed along, set the training seed now.
if args.seed is not None:
set_seed(args.seed)
if accelerator.is_main_process:
if args.output_dir is not None:
os.makedirs(args.output_dir, exist_ok=True)
accelerator.wait_for_everyone()
if args.dataset_mixer is not None:
args.dataset_mixer_list = [item for pair in args.dataset_mixer.items() for item in pair]
with accelerator.main_process_first():
transform_fn_args = [
{"max_seq_length": args.max_seq_length},
{},
]
train_dataset = get_cached_dataset_tulu(
dataset_mixer_list=args.dataset_mixer_list,
dataset_mixer_list_splits=args.dataset_mixer_list_splits,
tc=tc,
dataset_transform_fn=args.dataset_transform_fn,
transform_fn_args=transform_fn_args,
target_columns=args.dataset_target_columns,
dataset_cache_mode=args.dataset_cache_mode,
dataset_config_hash=args.dataset_config_hash,
hf_entity=args.hf_entity,
dataset_local_cache_dir=args.dataset_local_cache_dir,
dataset_skip_cache=args.dataset_skip_cache,
)
train_dataset = train_dataset.shuffle(seed=args.seed)
train_dataset.set_format(type="pt")
if accelerator.is_main_process:
visualize_token(train_dataset[0][INPUT_IDS_KEY], tokenizer)
if args.cache_dataset_only:
return
# Load pretrained model and tokenizer
if args.config_name:
config = AutoConfig.from_pretrained(
args.config_name,
revision=args.model_revision,
trust_remote_code=tc.trust_remote_code,
)
elif args.model_name_or_path:
config = AutoConfig.from_pretrained(
args.model_name_or_path,
revision=args.model_revision,
trust_remote_code=tc.trust_remote_code,
)
else:
raise ValueError(
"You are instantiating a new config instance from scratch. This is not supported by this script."
)
if args.model_name_or_path:
if args.use_qlora:
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
)
device_index = accelerator.local_process_index
device_map = {"": device_index} # force data-parallel training.
model = AutoModelForCausalLM.from_pretrained(
args.model_name_or_path,
revision=args.model_revision,
from_tf=bool(".ckpt" in args.model_name_or_path),
config=config,
trust_remote_code=tc.trust_remote_code,
quantization_config=bnb_config,
device_map=device_map,
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_2" if args.use_flash_attn else "eager",
)
elif args.use_liger_kernel:
from liger_kernel.transformers import AutoLigerKernelForCausalLM
fused_linear_cross_entropy = args.reduce_loss == "mean"
logger.info(f"Attempting to apply liger-kernel. {fused_linear_cross_entropy=}")
# Supported models: https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/transformers/monkey_patch.py#L948
model = AutoLigerKernelForCausalLM.from_pretrained(
args.model_name_or_path,
revision=args.model_revision,
from_tf=bool(".ckpt" in args.model_name_or_path),
config=config,
trust_remote_code=tc.trust_remote_code,
low_cpu_mem_usage=args.low_cpu_mem_usage,
use_flash_attention_2=True if args.use_flash_attn else False,
# liger-kernel specific args
fused_linear_cross_entropy=fused_linear_cross_entropy,
)
else:
model = AutoModelForCausalLM.from_pretrained(
args.model_name_or_path,
revision=args.model_revision,
from_tf=bool(".ckpt" in args.model_name_or_path),
config=config,
trust_remote_code=tc.trust_remote_code,
low_cpu_mem_usage=args.low_cpu_mem_usage,
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_2" if args.use_flash_attn else "eager",
)
else:
logger.info("Training new model from scratch")
model = AutoModelForCausalLM.from_config(config)
# We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
# on a small vocab and want a smaller embedding size, remove this test.
# gather deepspeed to get "real" embedding size
embeddings = model.get_input_embeddings()
with deepspeed.zero.GatheredParameters(embeddings.weight, modifier_rank=None):
embedding_size = embeddings.weight.shape[0]
# resize does its own gather
if len(tokenizer) > embedding_size:
# pad to multiple for tensor cores.
model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=8)
# update embedding size after resizing for sum loss
embeddings = model.get_input_embeddings()
with deepspeed.zero.GatheredParameters(embeddings.weight, modifier_rank=None):
embedding_size = embeddings.weight.shape[0]
if args.use_lora:
if args.use_qlora:
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=args.gradient_checkpointing)
logger.info("Initializing LORA model...")
peft_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
inference_mode=False,
r=args.lora_rank,
lora_alpha=args.lora_alpha,
lora_dropout=args.lora_dropout,
target_modules=[
"q_proj",
"o_proj",
"v_proj",
"k_proj",
"gate_proj",
"up_proj",
"down_proj",
],
)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()
elif args.gradient_checkpointing:
model.gradient_checkpointing_enable()
# DataLoaders creation:
train_dataloader = DataLoader(
train_dataset,
shuffle=True,
collate_fn=DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model, padding="longest"),
batch_size=args.per_device_train_batch_size,
)
# Optimizer
# Split weights in two groups, one with weight decay and the other not.
no_decay = ["bias", "layer_norm.weight"]
optimizer_grouped_parameters = [
{
"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
"weight_decay": args.weight_decay,
},
{
"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
"weight_decay": 0.0,
},
]
if args.use_qlora:
from bitsandbytes.optim import AdamW
optimizer = AdamW(
optimizer_grouped_parameters,
lr=args.learning_rate,
optim_bits=8 if args.use_8bit_optimizer else 32,
is_paged=True,
)
else:
optimizer = torch.optim.AdamW(
optimizer_grouped_parameters,
lr=args.learning_rate,
fused=args.fused_optimizer,
)
# Scheduler and math around the number of training steps.
overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
overrode_max_train_steps = True
# Create the learning rate scheduler.
# Note: the current accelerator.step() calls the .step() of the real scheduler
# for the `num_processes` times. This is because they assume
# the user initialize the scheduler with the entire training set.
# In the case of data parallel training, each process only
# sees a subset (1/num_processes) of the training set.
# So each time the process needs to update the lr multiple times so that the total
# number of updates in the end matches the num_training_steps here.
# Here we need to set the num_training_steps to either using the
# entire training set (when epochs is specified) or we need to multiply the
# num_training_steps by num_processes so that the total number of
# updates matches the num_training_steps.
num_training_steps_for_scheduler = (
args.max_train_steps if overrode_max_train_steps else args.max_train_steps * accelerator.num_processes
)
lr_scheduler = get_scheduler(
name=args.lr_scheduler_type,
optimizer=optimizer,
num_training_steps=num_training_steps_for_scheduler,
num_warmup_steps=int(num_training_steps_for_scheduler * args.warmup_ratio),
)
# Prepare everything with `accelerator`.
model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
model, optimizer, train_dataloader, lr_scheduler
)
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if overrode_max_train_steps:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
# Afterwards we recalculate our number of training epochs
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
# Figure out how many steps we should save the Accelerator states
checkpointing_steps = args.checkpointing_steps
if checkpointing_steps is not None and str(checkpointing_steps).lower() != "epoch":
checkpointing_steps = int(checkpointing_steps)
# Train!
total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
logger.info("***** Running training *****")
logger.info(f" Num examples = {len(train_dataset)}")
logger.info(f" Num Epochs = {args.num_train_epochs}")
logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}")
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
logger.info(f" Total optimization steps = {args.max_train_steps}")
# Only show the progress bar once on each machine.
progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
completed_steps = 0
starting_epoch = 0
# Potentially load in the weights and states from a previous save
last_checkpoint_path = get_last_checkpoint_path(args)
if last_checkpoint_path:
accelerator.print(f"Resumed from checkpoint: {last_checkpoint_path}")
accelerator.load_state(last_checkpoint_path)
# Extract `epoch_{i}` or `step_{i}`
last_checkpoint_path = os.path.basename(last_checkpoint_path)
training_difference = os.path.splitext(last_checkpoint_path)[0]
if "epoch" in training_difference:
starting_epoch = int(training_difference.replace("epoch_", "")) + 1
resume_step = None
completed_steps = starting_epoch * num_update_steps_per_epoch
else:
# need to multiply `gradient_accumulation_steps` to reflect real steps
resume_step = int(training_difference.replace("step_", "")) * args.gradient_accumulation_steps
starting_epoch = resume_step // len(train_dataloader)
completed_steps = resume_step // args.gradient_accumulation_steps
resume_step -= starting_epoch * len(train_dataloader)
print(f"Starting from epoch {starting_epoch} and step {completed_steps}.")
# update the progress_bar if load from checkpoint
progress_bar.update(completed_steps)
local_total_tokens = torch.tensor(0, dtype=torch.int64, device=accelerator.device)
total_token_including_padding = torch.tensor(0, dtype=torch.int64, device=accelerator.device)
start_time = time.time()
for epoch in range(starting_epoch, args.num_train_epochs):
model.train()
train_dataloader.set_epoch(epoch)
total_loss = 0
total_aux_loss = 0
if last_checkpoint_path and resume_step is not None:
# We skip the first `n` batches in the dataloader when resuming from a checkpoint
active_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
else:
active_dataloader = train_dataloader
for step, batch in enumerate(active_dataloader):
local_total_tokens += batch["attention_mask"].sum()
total_token_including_padding += batch["attention_mask"].numel()
with accelerator.accumulate(model):
if args.load_balancing_loss:
outputs = model(**batch, use_cache=False, output_router_logits=True)
else:
# TODO: we have calculated the mean loss here anyway, so doubling the calculation
outputs = model(**batch, use_cache=False)
if args.reduce_loss == "mean":
loss = outputs.loss
else:
# reduce loss is sum
# this ensures that we weight all tokens in the dataset equally,
# rather than weighting each overall example equally when
# using high amounts of gradient accumulation.
# this can result in > 5 point improvements in AlpacaEval
# see https://github.com/huggingface/transformers/issues/24725 for
# more discussion and details.
logits = outputs.logits
labels = batch["labels"]
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = torch.nn.CrossEntropyLoss(reduction="sum")
shift_logits = shift_logits.view(-1, embedding_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
if args.load_balancing_loss:
aux_loss = args.load_balancing_weight * outputs.aux_loss
loss += aux_loss
# We keep track of the loss at each logged step
total_loss += loss.detach().float()
accelerator.backward(loss)
if args.load_balancing_loss:
total_aux_loss += aux_loss.detach().float()
# clip gradient norm. don't do this with deepspeed
if accelerator.sync_gradients and args.clip_grad_norm > 0:
accelerator.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
optimizer.step()
optimizer.zero_grad()
lr_scheduler.step()
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
completed_steps += 1
if args.logging_steps and completed_steps % args.logging_steps == 0:
avg_loss = (
accelerator.gather(total_loss).mean().item()
/ args.gradient_accumulation_steps
/ args.logging_steps
)
total_tokens = accelerator.gather(local_total_tokens).sum().item()
total_tokens_including_padding = accelerator.gather(total_token_including_padding).sum().item()
metrics_to_log = {
"learning_rate": lr_scheduler.get_last_lr()[0],
"train_loss": avg_loss,
"total_tokens": total_tokens,
"per_device_tps": total_tokens / accelerator.num_processes / (time.time() - start_time),
"total_tokens_including_padding": total_tokens_including_padding,
"per_device_tps_including_padding": total_tokens_including_padding
/ accelerator.num_processes
/ (time.time() - start_time),
}
if args.load_balancing_loss:
avg_aux_loss = (
accelerator.gather(total_aux_loss).mean().item()
/ args.gradient_accumulation_steps
/ args.logging_steps
)
logger.info(
f" Step: {completed_steps}, LR: {lr_scheduler.get_last_lr()[0]}, Loss: {avg_loss}, Aux Loss: {avg_aux_loss}, TPS: {total_tokens / (time.time() - start_time)}"
)
metrics_to_log["aux_loss"] = avg_aux_loss
else:
logger.info(
f" Step: {completed_steps}, LR: {lr_scheduler.get_last_lr()[0]}, Loss: {avg_loss}, TPS: {total_tokens / (time.time() - start_time)}"
)
if args.with_tracking:
accelerator.log(
metrics_to_log,
step=completed_steps,
)
total_loss = 0
total_aux_loss = 0
if isinstance(checkpointing_steps, int):
if completed_steps % checkpointing_steps == 0:
output_dir = f"step_{completed_steps}"
if args.output_dir is not None:
output_dir = os.path.join(args.output_dir, output_dir)
accelerator.save_state(output_dir)
# use this to mark the checkpoint as completely saved, to avoid restoring from garbled checkpoints
with open(
os.path.join(
get_last_checkpoint_path(args, incomplete=True),
"COMPLETED",
),
"w",
) as f:
f.write("COMPLETED") # annoyingly, empty files arent uploaded by beaker.
if (
accelerator.is_local_main_process
): # TODO: in mason local model this is gonna error out if using something like output/test; because mason used the same shared file ssytem.
clean_last_n_checkpoints(args.output_dir, args.keep_last_n_checkpoints)
accelerator.wait_for_everyone()
if completed_steps >= args.max_train_steps:
break
if checkpointing_steps == "epoch":
output_dir = f"epoch_{epoch}"
if args.output_dir is not None:
output_dir = os.path.join(args.output_dir, output_dir)
accelerator.save_state(output_dir)
# use this to mark the checkpoint as completely saved, to avoid restoring from garbled checkpoints
with open(
os.path.join(get_last_checkpoint_path(args, incomplete=True), "COMPLETED"),
"w",
) as f:
f.write("COMPLETED") # annoyingly, empty files arent uploaded by beaker.
if accelerator.is_local_main_process:
clean_last_n_checkpoints(args.output_dir, args.keep_last_n_checkpoints)
accelerator.wait_for_everyone()
if args.output_dir is not None:
save_with_accelerate(
accelerator,
model,
tokenizer,
args.output_dir,
args.use_lora,
)
# remove all checkpoints to save space
if accelerator.is_local_main_process:
clean_last_n_checkpoints(args.output_dir, keep_last_n_checkpoints=0)
if (
args.try_auto_save_to_beaker
and accelerator.is_main_process
and is_beaker_job()
and len(beaker_config.beaker_dataset_id_urls) > 0
and args.output_dir.rstrip("/") != "/output"
):
shutil.copytree(args.output_dir, "/output", dirs_exist_ok=True)
if is_beaker_job() and accelerator.is_main_process and args.try_launch_beaker_eval_jobs:
launch_ai2_evals_on_weka(
path=args.output_dir,
leaderboard_name=args.hf_repo_revision,
oe_eval_max_length=args.oe_eval_max_length,
wandb_url=wandb_tracker.run.get_url(),
oe_eval_tasks=args.oe_eval_tasks,
gs_bucket_path=args.gs_bucket_path,
)
if args.push_to_hub:
push_folder_to_hub(
accelerator,
args.output_dir,
args.hf_repo_id,
args.hf_repo_revision,
)
accelerator.wait_for_everyone()
if args.with_tracking:
accelerator.end_training()
if __name__ == "__main__":
    # Script entry point: build a parser over the two argument dataclasses,
    # unpack the two parsed objects it returns, and pass both to `main`.
    dataclass_types = (FlatArguments, TokenizerConfig)
    parser = ArgumentParserPlus(dataclass_types)
    args, tc = parser.parse_args_into_dataclasses()
    main(args, tc)