Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions FAQS.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# FAQs

- Can you train StableLM with this? Yes, but only with a single GPU atm. Multi GPU support is coming soon! Just waiting on this [PR](https://github.com/huggingface/transformers/pull/22874)
- Will this work with Deepspeed? That's still a WIP, but setting `export ACCELERATE_USE_DEEPSPEED=true` should work in some cases
41 changes: 41 additions & 0 deletions configs/galactica_1_3B.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
base_model: facebook/galactica-1.3b
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
load_in_8bit: false
datasets:
- path: tatsu-lab/alpaca
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
adapter:
lora_model_dir:
sequence_len: 1024
max_packed_sequence_len: 1024
lora_r: 8
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
- q_proj
- v_proj
lora_fan_in_fan_out: false
wandb_project:
wandb_watch:
wandb_run_id:
wandb_log_model: checkpoint
output_dir: ./lora-llama-alpaca
batch_size: 32
micro_batch_size: 16
num_epochs: 3
learning_rate: 0.00003
train_on_inputs: false
group_by_length: false
bf16: false
tf32: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:
special_tokens:
pad_token: "[PAD]"
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"
39 changes: 39 additions & 0 deletions configs/llama_13B_alpaca.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
base_model: huggyllama/llama-13b
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
load_in_8bit: true
datasets:
- path: anon8231489123/ShareGPT_Vicuna_unfiltered
data_files: ShareGPT_V3_unfiltered_cleaned_split_no_imsorry.json
type: sharegpt
dataset_prepared_path: last_run_prepared
val_set_size: 0.002
adapter:
lora_model_dir:
sequence_len: 2048
lora_r: 8
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
- q_proj
- v_proj
lora_fan_in_fan_out: false
wandb_project:
wandb_watch:
wandb_run_id:
wandb_log_model: checkpoint
output_dir: ./llama-13b-sharegpt
batch_size: 64
micro_batch_size: 2
warmup_steps: 1000
save_steps:
eval_steps:
num_epochs: 5
learning_rate: 0.00003
train_on_inputs: false
group_by_length: false
bf16: true
tf32: true
early_stopping_patience: 5
resume_from_checkpoint:
local_rank:
5 changes: 4 additions & 1 deletion configs/llama_65B_alpaca.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ load_in_8bit: true
datasets:
- path: data/alpaca_data_gpt4.jsonl
type: alpaca
- path: data/vicuna_cleaned.jsonl
- path: anon8231489123/ShareGPT_Vicuna_unfiltered
data_files: ShareGPT_V3_unfiltered_cleaned_split_no_imsorry.json
type: sharegpt
- path: data/gpt4-instruct-similarity-0.6-dataset.jsonl
type: gpteacher
Expand All @@ -30,6 +31,8 @@ wandb_log_model: checkpoint
output_dir: ./lora-llama-alpaca
batch_size: 128
micro_batch_size: 16
warmup_steps: 1000
save_steps:
num_epochs: 5
learning_rate: 0.00003
train_on_inputs: false
Expand Down
33 changes: 33 additions & 0 deletions configs/stability_3b.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
base_model: stabilityai/stablelm-base-alpha-3b
load_in_8bit: true
datasets:
- path: vicgalle/alpaca-gpt4
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.04
adapter:
lora_model_dir:
sequence_len: 4096
lora_r: 8
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
- q_proj
- v_proj
lora_fan_in_fan_out: false
wandb_project: stable-llama-3b
wandb_watch:
wandb_run_id:
wandb_log_model: checkpoint
output_dir: ./stable-llama-3b
batch_size: 128
micro_batch_size: 16
num_epochs: 1
learning_rate: 0.00003
train_on_inputs: false
group_by_length: false
bf16: true
tf32: true
early_stopping_patience: 3
resume_from_checkpoint:
local_rank:
10 changes: 5 additions & 5 deletions ds_config.json
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,10 @@
"min_loss_scale": 1
},
"scheduler": {
"type": "WarmupLR",
"type": "OneCycle",
"params": {
"warmup_min_lr": "auto",
"warmup_max_lr": "auto",
"warmup_num_steps": "auto"
"cycle_min_lr": 1e-7,
"cycle_max_lr": 1e-4
}
},
"zero_optimization": {
Expand All @@ -25,7 +24,8 @@
"allgather_bucket_size": 5e8,
"contiguous_gradients": true,
"reduce_bucket_size": "auto",
"reduce_scatter": true
"reduce_scatter": true,
"stage3_gather_16bit_weights_on_model_save": true
},
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
Expand Down
2 changes: 1 addition & 1 deletion scripts/finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def train(
cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
choose_device(cfg)
cfg.ddp = cfg.world_size != 1
cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1
if cfg.ddp:
cfg.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))}
cfg.gradient_accumulation_steps = (
Expand Down
14 changes: 9 additions & 5 deletions src/axolotl/datasets.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
from typing import List

import torch
Expand Down Expand Up @@ -92,11 +93,14 @@ def __iter__(self):
: self.seq_length
]
labels = torch.cat(buffer["labels"], dim=-1)[: self.seq_length]
yield {
"input_ids": input_ids,
"labels": labels,
"attention_mask": attention_mask,
}
if labels.size() == input_ids.size() and attention_mask.size() == input_ids.size():
yield {
"input_ids": input_ids,
"labels": labels,
"attention_mask": attention_mask,
}
else:
logging.warning("dropping batch due to tensor size mismatch")
buffer = {"input_ids": [], "attention_mask": [], "labels": []}
buffer_len = 0

Expand Down
4 changes: 4 additions & 0 deletions src/axolotl/prompters.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,10 @@ def append_message(self, role, message):

class ShareGPTPrompter:
def build_prompt(self, source, tokenizer):
# ignore the system prompt if provided
if source[0]["from"] == "system":
source.pop(0)

if len(source) < 2:
# If there isn't a back and forth conversation, ignore it
# also happens on the data splitting leaving empty conversations
Expand Down
45 changes: 33 additions & 12 deletions src/axolotl/utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
from hashlib import md5
from pathlib import Path

from datasets import load_from_disk, load_dataset, IterableDataset, Dataset
from datasets import load_from_disk, load_dataset, IterableDataset, Dataset, concatenate_datasets
from huggingface_hub import hf_hub_download

from axolotl.datasets import TokenizedPromptDataset, ConstantLengthDataset
from axolotl.prompt_tokenizers import (
Expand Down Expand Up @@ -30,7 +31,7 @@ def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
ds_hash = str(
md5(
(
str(max_packed_sequence_len)
str(cfg.sequence_len)
+ "@"
+ "|".join(sorted([f"{d.path}:{d.type}" for d in cfg.datasets]))
).encode("utf-8")
Expand All @@ -43,13 +44,15 @@ def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
)

if any(prepared_ds_path.glob("*")):
logging.info("Loading prepared dataset from disk...")
logging.info(f"Loading prepared dataset from disk ay {prepared_ds_path}...")
dataset = load_from_disk(str(prepared_ds_path))
logging.info("Prepared dataset loaded from disk...")
else:
logging.info(f"Unable to find prepared dataset in {prepared_ds_path}")
logging.info("Loading raw datasets...")
datasets = []
for d in cfg.datasets:
ds = None
ds_from_hub = False
try:
load_dataset(d.path, streaming=True)
Expand All @@ -63,8 +66,14 @@ def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
"json", data_files=d.path, streaming=True, split=None
)
elif ds_from_hub:
ds = load_dataset(d.path, streaming=True)
if d.data_files:
ds = load_dataset(d.path, streaming=True, data_files=d.data_files)
else:
ds = load_dataset(d.path, streaming=True)
else:
fp = hf_hub_download(repo_id=d.path, repo_type="dataset", filename=d.data_files)
ds = load_dataset("json", data_files=fp, streaming=True, split=None)
if not ds:
raise Exception("unhandled dataset load")

if d.type == "alpaca":
Expand Down Expand Up @@ -105,20 +114,32 @@ def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
datasets.append(ds_wrapper)
else:
logging.error(f"unhandled prompt tokenization strategy: {d.type}")
logging.info("tokenizing, merging, and shuffling master dataset")

samples = []
for d in datasets:
samples = samples + [i for i in d]
dataset = Dataset.from_list(samples).shuffle(seed=42)
if cfg.local_rank == 0:
logging.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")
dataset.save_to_disk(prepared_ds_path)

if cfg.max_packed_sequence_len is not None:
constant_len_dataset = ConstantLengthDataset(
tokenizer,
datasets,
[dataset],
seq_length=max_packed_sequence_len,
)
logging.info("merging, packing, shuffling, and splitting master dataset")
dataset = Dataset.from_list([_ for _ in constant_len_dataset]).train_test_split(
test_size=cfg.val_set_size, shuffle=True, seed=42
)
logging.info(f"packing master dataset to len: {cfg.max_packed_sequence_len}")
dataset = Dataset.from_list([_ for _ in constant_len_dataset])

if cfg.local_rank == 0:
logging.info(f"Saving prepared dataset to disk... {prepared_ds_path}")
dataset.save_to_disk(prepared_ds_path)
if cfg.dataset_shard_num and cfg.dataset_shard_idx is not None:
logging.info(f"Using index #{cfg.dataset_shard_idx} of {cfg.dataset_shard_num} shards")
dataset = dataset.shard(num_shards=cfg.dataset_shard_num, index=cfg.dataset_shard_idx)

dataset = dataset.train_test_split(
test_size=cfg.val_set_size, shuffle=False
)
train_dataset = dataset["train"]
eval_dataset = dataset["test"]

Expand Down
39 changes: 31 additions & 8 deletions src/axolotl/utils/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,16 @@
import transformers
from transformers import (
AutoModelForCausalLM,
LlamaForCausalLM,
LlamaTokenizer,
AutoTokenizer,
PreTrainedModel,
)
try:
from transformers import (
LlamaForCausalLM,
LlamaTokenizer,
)
except:
logging.warning("This version of transformers does not support Llama. Consider upgrading.")

from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN

Expand Down Expand Up @@ -70,7 +75,7 @@ def load_model(
snapshot_download_kwargs = {}
if cfg.base_model_ignore_patterns:
snapshot_download_kwargs["ignore_patterns"] = cfg.base_model_ignore_patterns
cache_model_path = Path(snapshot_download(base_model, ** snapshot_download_kwargs))
cache_model_path = Path(snapshot_download(base_model, **snapshot_download_kwargs))
files = (
list(cache_model_path.glob("*.pt"))
+ list(cache_model_path.glob("*.safetensors"))
Expand All @@ -95,15 +100,29 @@ def load_model(
else True,
)
load_in_8bit = False
elif is_llama_derived_model:
model = LlamaForCausalLM.from_pretrained(
elif is_llama_derived_model and "LlamaForCausalLM" in globals():
if not cfg.load_in_8bit:
model = LlamaForCausalLM.from_pretrained(
base_model,
device_map=cfg.device_map,
)
else:
model = LlamaForCausalLM.from_pretrained(
base_model,
load_in_8bit=cfg.load_in_8bit,
torch_dtype=torch_dtype,
device_map=cfg.device_map,
)

elif model_type:
model = getattr(transformers, model_type).from_pretrained(
base_model,
load_in_8bit=cfg.load_in_8bit,
torch_dtype=torch_dtype,
device_map=cfg.device_map,
)
else:
model = getattr(transformers, model_type).from_pretrained(
model = AutoModelForCausalLM.from_pretrained(
base_model,
load_in_8bit=cfg.load_in_8bit,
torch_dtype=torch_dtype,
Expand All @@ -123,7 +142,7 @@ def load_model(

if not tokenizer:
try:
if is_llama_derived_model:
if is_llama_derived_model and "LlamaTokenizer" in globals():
tokenizer = LlamaTokenizer.from_pretrained(model)
else:
tokenizer = getattr(transformers, tokenizer_type).from_pretrained(model)
Expand All @@ -142,13 +161,17 @@ def load_model(
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
os.environ["TOKENIZERS_PARALLELISM"] = "false"

if cfg.special_tokens:
for k, v in cfg.special_tokens.items():
setattr(tokenizer, k, v)

if load_in_8bit and not cfg.load_4bit:
logging.info("converting model w/ prepare_model_for_int8_training")
model = prepare_model_for_int8_training(model)

model, lora_config = load_adapter(model, cfg, adapter)

if cfg.ddp:
if cfg.ddp and not load_in_8bit:
model.to(f"cuda:{cfg.local_rank}")

if cfg.load_4bit:
Expand Down
Loading