Cannot apply both PEFT QLoRA and DeepSpeed ZeRO3 #2016

Closed

echo-yi opened this issue Aug 19, 2024 · 9 comments
Comments

@echo-yi commented Aug 19, 2024

System Info

- `Accelerate` version: 0.33.0
- Platform: Linux-5.15.133+-x86_64-with-glibc2.35
- `accelerate` bash location: /opt/conda/bin/accelerate
- Python version: 3.10.14
- Numpy version: 1.25.2
- PyTorch version (GPU?): 2.2.0+cu121 (True)
- PyTorch XPU available: False
- PyTorch NPU available: False
- PyTorch MLU available: False
- PyTorch MUSA available: False
- System RAM: 1842.60 GB
- GPU type: NVIDIA H100 80GB HBM3
- `Accelerate` default config:
	- compute_environment: LOCAL_MACHINE
	- distributed_type: NO
	- mixed_precision: no
	- use_cpu: True
	- debug: False
	- num_processes: 1
	- machine_rank: 0
	- num_machines: 1
	- rdzv_backend: static
	- same_network: False
	- main_training_function: main
	- enable_cpu_affinity: False
	- downcast_bf16: False
	- tpu_use_cluster: False
	- tpu_use_sudo: False

Who can help?

@stevhliu

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder
  • My own task or dataset (give details below)

Reproduction

This line: model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3.1-405B-Instruct", ...) throws a CUDA OOM error, because the parameters are not partitioned but copied to every GPU.

command
accelerate launch --config_file zero3_config.yaml pretrain.py --num_processes=8 --multi_gpu

pretrain.py

...
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    # bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_storage=torch.bfloat16,
)

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3.1-405B-Instruct",
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
)
...

zero3_config.yaml

compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  deepspeed_multinode_launcher: standard
  offload_optimizer_device: none
  offload_param_device: none
  zero3_init_flag: true
  zero3_save_16bit_model: true
  zero_stage: 3
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

Expected behavior

PEFT QLoRA (with bitsandbytes) and DeepSpeed ZeRO3 are both applied, so that the model parameters are quantized and partitioned.
I thought this should work according to this post, but microsoft/DeepSpeed#5819 says bitsandbytes quantization and ZeRO3 are not compatible. If that is the case, I find the above post quite misleading.
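
For context, a minimal sketch of the usual QLoRA adapter setup from the PEFT/Transformers docs is shown below; the model name, LoRA rank, and target modules are illustrative and not taken from pretrain.py.

# Minimal sketch of the QLoRA recipe; hyperparameters here are illustrative.
import torch
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    # For sharded training (ZeRO3/FSDP) the storage dtype should match torch_dtype.
    bnb_4bit_quant_storage=torch.bfloat16,
)

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3.1-8B-Instruct",  # smaller model for a sanity check
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16,
)

lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules="all-linear",
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()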

@BenjaminBossan (Member) commented Aug 19, 2024

@echo-yi Does it work for you with a smaller model, like the example from the PEFT docs?

@matthewdouglas @Titus-von-Koeller Could you please take a look? Could it be an issue specific to Llama 405B?

More context: huggingface/transformers#29587

> Disable zero.init when using DeepSpeed with QLoRA.

I wonder if this is still needed?
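
(Side note for anyone debugging this: whether transformers considers itself to be in a ZeRO-3 / zero.init context at load time can be checked from the training script. The import path below is from recent transformers releases and may differ in older versions.)

# Heuristic check: if this prints True, from_pretrained will try to build the model
# under deepspeed.zero.Init unless something (such as a quantization config, per the
# PR referenced above) turns zero.init off.
from transformers.integrations import is_deepspeed_zero3_enabled

print("ZeRO-3 context detected:", is_deepspeed_zero3_enabled())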

@echo-yi (Author) commented Aug 19, 2024

@BenjaminBossan I tried "meta-llama/Meta-Llama-3.1-8B-Instruct" and "meta-llama/Meta-Llama-3.1-70B-Instruct", and neither worked.

@BenjaminBossan (Member) commented

Thanks for testing those. Since this error occurs already at the stage of loading the base model, it is not directly a PEFT error, though of course PEFT is affected, and I'd be ready to update the docs if it is confirmed that DS ZeRO3 doesn't work with bnb. I hope the bnb authors can shed light on this.

@echo-yi (Author) commented Aug 20, 2024

@tjruwase from DeepSpeed shared this line, indicating that applying both quantization and DS ZeRO3 doesn't work.

@BenjaminBossan (Member) commented

> shared this line, indicating that applying both quantization and DS ZeRO3 doesn't work

Yeah, that was added in the PR I mentioned earlier.

I can confirm that even for smaller models, partitioning does not appear to work. But when I remove quantization and use device_map="auto", the same picture emerges. So I'm actually unsure if there is an issue here with bitsandbytes usage in DeepSpeed ZeRO3 or if something else is amiss.
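
One way to tell replication from real partitioning is to look for DeepSpeed's metadata on the parameters right after from_pretrained returns. The ds_* attributes are DeepSpeed internals rather than a supported API, so treat the following as a rough heuristic.

# Rough partitioning check, intended to be run on each rank after loading the model.
import torch

def report_partitioning(model):
    local_elements = sum(p.numel() for p in model.parameters())
    zero3_params = sum(1 for p in model.parameters() if hasattr(p, "ds_numel"))
    print(f"locally materialized elements: {local_elements:,}")
    print(f"parameters carrying ZeRO-3 metadata: {zero3_params}")
    if torch.cuda.is_available():
        print(f"allocated GPU memory: {torch.cuda.memory_allocated() / 1e9:.2f} GB")

# With working ZeRO-3 partitioning the locally materialized element count stays far
# below the full model size; with plain replication it matches the full model.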

@echo-yi (Author) commented Aug 21, 2024

@BenjaminBossan When I remove ZeRO3 and use quantization & device_map="auto", partitioning does appear to work.
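
(For clarity, device_map="auto" spreads whole layers across the GPUs via Accelerate's big-model loading path rather than ZeRO-style parameter partitioning, so it avoids the OOM through a different mechanism. A rough sketch of that path, reusing the 70B model from the earlier test:)

# Sketch: 4-bit loading without DeepSpeed; Accelerate places whole layers per GPU.
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3.1-70B-Instruct",
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)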

@BenjaminBossan (Member) commented

Also pinging @muellerzr in case he knows something about this.

@winglian (Contributor) commented Sep 5, 2024

I tested in axolotl against the latest transformers, and this seems to work with this QLoRA + PEFT + ZeRO3 YAML:

base_model: NousResearch/Meta-Llama-3-8B

load_in_4bit: true
datasets:
  - path: tatsu-lab/alpaca
    type: alpaca
val_set_size: 0.0
output_dir: ./outputs/lora-out

sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true

adapter: qlora
lora_r: 32
lora_alpha: 64
lora_dropout: 0.05
lora_target_linear: true

gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 2
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002

bf16: auto
tf32: false

gradient_checkpointing: true
logging_steps: 1
flash_attention: true

warmup_steps: 10
evals_per_epoch: 4
saves_per_epoch: 1
deepspeed: deepspeed_configs/zero3_bf16.json
weight_decay: 0.1
special_tokens:
  pad_token: <|end_of_text|>
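
Outside axolotl, a ZeRO-3 JSON like the one referenced above can be passed straight to the HF Trainer; a rough equivalent of the training arguments in that YAML (paths and hyperparameters are illustrative):

# Sketch: handing an axolotl-style ZeRO-3 config file to the HF Trainer directly.
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="./outputs/lora-out",
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    num_train_epochs=2,
    learning_rate=2e-4,
    bf16=True,
    gradient_checkpointing=True,
    logging_steps=1,
    weight_decay=0.1,
    deepspeed="deepspeed_configs/zero3_bf16.json",
)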

@github-actions (bot) commented

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

@github-actions github-actions bot closed this as completed Oct 8, 2024