LoRA & other fine-tuning methods for LLaMA

cedrickchee/minimal-llama

Minimal LLaMA

This repo contains a random assortment of code for running and fine-tuning LLaMA. Many parts are still a work in progress. There are likely more efficient ways to fine-tune (DeepSpeed / ZeRO, NeoX) than the ones presented here, but folks may already find this useful.

This code was thrown together fairly quickly and may contain many, many bugs. Feedback is welcome!

Tokenize datasets

First, we tokenize the data so we never have to worry about the tokenizer again. The tokenization script takes a JSONL file (each row containing a "text" key with the document text) and effectively concatenates, tokenizes, and slices the documents into max_seq_length chunks.

(This is a quick and dirty script that loads the whole dataset into memory.)

python tokenize_dataset.py \
    --tokenizer_path /path/to/tokenizer \
    --jsonl_path /path/to/data.jsonl \
    --save_path /path/to/tokenized_dataset \
    --max_seq_length 512
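The concatenate-tokenize-slice step described above can be sketched as follows. This is a minimal illustration, not tokenize_dataset.py itself: `tokenize_and_chunk` and `toy_encode` are hypothetical names, `encode` stands in for the LLaMA tokenizer, and dropping the trailing partial chunk is an assumption about how the real script handles leftovers.

```python
import json
from typing import Callable, Iterable, List


def tokenize_and_chunk(
    jsonl_lines: Iterable[str],
    encode: Callable[[str], List[int]],
    max_seq_length: int,
) -> List[List[int]]:
    """Tokenize each JSONL row's "text", join all token IDs, and slice into fixed-size chunks."""
    all_ids: List[int] = []
    for line in jsonl_lines:
        if not line.strip():
            continue  # skip blank lines
        all_ids.extend(encode(json.loads(line)["text"]))
    # Keep only full chunks; the trailing partial chunk is dropped (an assumption)
    num_chunks = len(all_ids) // max_seq_length
    return [
        all_ids[i * max_seq_length:(i + 1) * max_seq_length]
        for i in range(num_chunks)
    ]


# Toy stand-in for the LLaMA tokenizer: one "token" per character (hypothetical)
def toy_encode(text: str) -> List[int]:
    return [ord(c) for c in text]


rows = ['{"text": "abcd"}', '{"text": "efg"}']
print(tokenize_and_chunk(rows, toy_encode, max_seq_length=3))
# → [[97, 98, 99], [100, 101, 102]]
```

With the real tokenizer, `encode` would wrap the LLaMA tokenizer's encode call; whether BOS/EOS tokens are inserted between documents is a detail this sketch leaves out. Because the sequence boundaries ignore document boundaries, every chunk is exactly max_seq_length tokens long and no padding is needed.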

PEFT Fine-tuning with 8-bit

Requires using the Transformers PR here, based on the fork here. Model weights need to be converted to HF format using the weight conversion script in the PR.

Requires using the PEFT PR here, based on the fork here.

We can fine-tune using the PEFT library, with the model converted to 8-bit. This is based on the guide here.

python finetune_peft.py \
    --model_path /path/to/llama-7b/ \
    --dataset_path /path/to/tokenized_dataset \
    --peft_mode lora \
    --lora_rank 8 \
    --per_device_train_batch_size 2 \
    --gradient_accumulation_steps 1 \
    --max_steps 2500 \
    --learning_rate 2e-4 \
    --fp16 \
    --logging_steps 10 \
    --output_dir /path/to/save

The above configuration (with max_seq_length=512) uses about 20 GB of GPU memory on a single GPU. (With a batch size of 1 and max_seq_length=256, this drops to about 12 GB.)
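Part of why this fits on one GPU is that LoRA trains very few parameters. Each adapted weight matrix W (d_out × d_in) stays frozen, and a low-rank update B·A is learned instead, with B of shape d_out × r and A of shape r × d_in, so each matrix adds r × (d_in + d_out) trainable parameters. Since the base weights are frozen, gradients and optimizer states are needed only for these parameters. The sketch below counts them for LLaMA-7B (hidden size 4096, 32 decoder layers) at --lora_rank 8. Which modules finetune_peft.py actually adapts is an assumption here (the query and value projections, two per layer), and the function names are hypothetical.

```python
def lora_params_per_matrix(d_in: int, d_out: int, rank: int) -> int:
    """LoRA trains B (d_out x rank) and A (rank x d_in) next to a frozen d_out x d_in weight."""
    return rank * (d_in + d_out)


def total_lora_params(hidden_size: int, num_layers: int, rank: int, matrices_per_layer: int) -> int:
    """Assumes every adapted matrix is a square hidden_size x hidden_size attention projection."""
    return num_layers * matrices_per_layer * lora_params_per_matrix(hidden_size, hidden_size, rank)


# LLaMA-7B: hidden size 4096, 32 decoder layers.
# Assumption: LoRA targets the query and value projections (2 matrices per layer).
trainable = total_lora_params(hidden_size=4096, num_layers=32, rank=8, matrices_per_layer=2)
print(f"{trainable:,} trainable parameters")  # → 4,194,304 trainable parameters
print(f"~{trainable / 7e9:.2%} of ~7B base parameters")
```

Raising --lora_rank scales this count linearly, and adapting more modules per layer (for example the MLP projections, which are 4096 × 11008 in LLaMA-7B) adds more per layer.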

You can generate using the trained PEFT params using something like the following:

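import torch
import transformers
from finetune_peft import get_peft_config, PEFTArguments
from peft import get_peft_model

model_path = ...      # converted HF-format LLaMA weights
peft_path = ...       # saved PEFT (LoRA) parameters from finetune_peft.py
tokenizer_path = ...

# Create new tensors as fp16 on the GPU while loading the base model
torch.set_default_tensor_type(torch.cuda.HalfTensor)
# Class names follow the Transformers PR fork; released versions use LlamaForCausalLM / LlamaTokenizer
model = transformers.LLaMAForCausalLM.from_pretrained(model_path)
# Wrap with a LoRA config; it must match the one used for training (e.g. the same lora_rank)
peft_config = get_peft_config(peft_args=PEFTArguments(peft_mode="lora"))
model = get_peft_model(model, peft_config)
# strict=False: the checkpoint is expected to contain only the PEFT parameters
model.load_state_dict(torch.load(peft_path), strict=False)
# Restore the fp32 default (new tensors still default to the GPU)
torch.set_default_tensor_type(torch.cuda.FloatTensor)

tokenizer = transformers.LLaMATokenizer.from_pretrained(tokenizer_path)
batch = tokenizer("The LLaMA language model is", return_tensors="pt")

with torch.no_grad():
    out = model.generate(
        input_ids=batch["input_ids"],
        attention_mask=torch.ones_like(batch["input_ids"]),
        max_length=200,  # total length, including the prompt tokens
    )
print(tokenizer.decode(out[0]))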

Fine-tuning with Naive Pipeline Parallel

Requires using the Transformers PR here, based on the fork here. Model weights need to be converted to HF format using the weight conversion script in the PR.

For full fine-tuning, we can use a very naively implemented version of pipeline parallelism, which splits the model's layers across several GPUs. This is preferable for larger models that won't fit on a single GPU.

python finetune_pp.py \
    --model_path /path/to/llama-7b/ \
    --dataset_path /path/to/tokenized_dataset \
    --save_dir /path/to/save \
    --batch_size 4 \
    --gradient_accumulation_steps 2 \
    --save_interval 2000 \
    --num_train_steps 20000

The above configuration uses about 30-35 GB of GPU memory on each of 8 GPUs.
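Naive pipeline parallelism places contiguous blocks of decoder layers on different GPUs and passes the hidden states from one GPU to the next. The helper below is a minimal sketch of only the placement step, assuming an as-even-as-possible contiguous split. `assign_layers_to_devices` is a hypothetical name, and finetune_pp.py may partition the model differently (for example, in where it puts the embeddings and output head).

```python
from typing import List


def assign_layers_to_devices(num_layers: int, num_devices: int) -> List[int]:
    """Hypothetical helper: map each layer index to a GPU index.

    Layers are split into contiguous blocks as evenly as possible; when the
    split is uneven, the first devices receive one extra layer each.
    """
    if num_devices <= 0:
        raise ValueError("num_devices must be positive")
    base, extra = divmod(num_layers, num_devices)
    assignment: List[int] = []
    for device in range(num_devices):
        count = base + (1 if device < extra else 0)
        assignment.extend([device] * count)
    return assignment


# LLaMA-7B has 32 decoder layers; across 8 GPUs that is 4 layers per GPU.
placement = assign_layers_to_devices(32, 8)
print(placement[:8])  # → [0, 0, 0, 0, 1, 1, 1, 1]
```

In PyTorch, each layer would then be moved with `layer.to(f"cuda:{device}")`, and the hidden states moved with `.to()` whenever consecutive layers sit on different devices. In this simplest form only one GPU computes at a time, which is part of what makes the approach naive.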

PEFT Fine-tuning with 8-bit and Pipeline Parallel

This seems buggy; don't use it yet.

Requires using the Transformers PR here, based on the fork here. Model weights need to be converted to HF format using the weight conversion script in the PR.

Requires using the PEFT PR here, based on the fork here.

Here, we combine PEFT training with pipeline parallelism to train large models. See the "PEFT Fine-tuning with 8-bit" section for more details.

python finetune_pp_peft.py \
    --model_path /path/to/llama-30b/ \
    --dataset_path /path/to/tokenized_dataset \
    --save_dir /path/to/save \
    --batch_size 4 \
    --learning_rate 5e-5 \
    --gradient_accumulation_steps 1 \
    --save_interval 2000 \
    --num_train_steps 20000 \
    --peft_mode lora \
    --lora_rank 8

For instance, you can train LoRA on 65B LLaMA with about 120 GB of GPU memory in total (e.g. 15 GB each on 8 GPUs, or 60 GB each on 2 GPUs), with batch size 1 and sequence length 512.
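As a rough sanity check of these numbers, the sketch below computes the memory taken by the weights alone and the per-GPU share of the total. It assumes about 65e9 parameters, 1 byte per parameter in 8-bit (2 in fp16), and an even split across GPUs. The helper names are hypothetical, and the part of the ~120 GB not covered by the weights presumably goes to activations, the LoRA parameters with their optimizer state, and framework overhead.

```python
def weight_memory_gb(num_params: float, bytes_per_param: float) -> float:
    """Memory for the model weights alone, in GB (1 GB = 1e9 bytes)."""
    return num_params * bytes_per_param / 1e9


def per_gpu_memory_gb(total_gb: float, num_gpus: int) -> float:
    """Per-GPU share, assuming memory is split evenly across pipeline stages."""
    return total_gb / num_gpus


# Rough figures for a 65B-parameter model (approximate parameter count).
int8_weights = weight_memory_gb(65e9, 1)  # 8-bit: 1 byte per parameter
fp16_weights = weight_memory_gb(65e9, 2)  # fp16: 2 bytes per parameter
print(f"8-bit weights: ~{int8_weights:.0f} GB, fp16 weights: ~{fp16_weights:.0f} GB")
for gpus in (8, 2):
    print(f"{gpus} GPUs: ~{per_gpu_memory_gb(120, gpus):.0f} GB each of a 120 GB total")
```

The fp16 figure (~130 GB) shows why 8-bit loading matters here: the fp16 weights alone would exceed the quoted 120 GB total.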

Misc Notes

  • I have no idea what hyperparameters are best for fine-tuning.
  • Aside from model parameters, gradients, and optimizer states, the hidden activations also take up a big chunk of memory. Shortening max_seq_length is a good way of reducing memory consumption, though I don't really know how much that affects fine-tuning performance either.
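To illustrate why the sequence length matters, the sketch below estimates two kinds of activation memory for LLaMA-7B (hidden size 4096, 32 heads, 32 layers) in fp16 at batch size 2: one hidden-state tensor per layer, which grows linearly with the sequence length, and one attention-score tensor per layer, which grows quadratically. These are deliberately partial lower bounds, assuming no activation checkpointing or memory-efficient attention; real training keeps many more intermediates per layer. The function names are hypothetical.

```python
def hidden_state_bytes(batch: int, seq_len: int, hidden: int, num_layers: int, bytes_per_elem: int = 2) -> int:
    """One [batch, seq_len, hidden] tensor saved per layer: linear in seq_len."""
    return batch * seq_len * hidden * num_layers * bytes_per_elem


def attention_score_bytes(batch: int, seq_len: int, num_heads: int, num_layers: int, bytes_per_elem: int = 2) -> int:
    """One [batch, heads, seq_len, seq_len] score tensor per layer: quadratic in seq_len."""
    return batch * num_heads * seq_len * seq_len * num_layers * bytes_per_elem


for seq_len in (512, 256):
    h = hidden_state_bytes(2, seq_len, 4096, 32) / 2**20
    a = attention_score_bytes(2, seq_len, 32, 32) / 2**20
    print(f"seq_len={seq_len}: hidden states ~{h:.0f} MiB, attention scores ~{a:.0f} MiB")
# → seq_len=512: hidden states ~256 MiB, attention scores ~1024 MiB
# → seq_len=256: hidden states ~128 MiB, attention scores ~256 MiB
```

Halving the sequence length halves the linear terms and quarters the quadratic ones, which is why shortening max_seq_length is such an effective lever.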
