How is "Full finetune 8B on 24GB card" achieved? #2168

Closed
fzyzcjy opened this issue Dec 18, 2024 · 5 comments
Labels: discussion, triaged

Comments

fzyzcjy commented Dec 18, 2024

Hi, thanks for the library! I am interested in the following line in the README, since fully finetuning an 8B model usually needs far more memory than a single 24GB card provides. I would appreciate knowing the exact techniques used to achieve this, or e.g. a configuration for reproduction.

I would also appreciate knowing whether it comes with drawbacks. For example, when the model is stored in bf16 (instead of fp32), I have observed significant performance degradation for 0.5B–1.5B models, so it may not be optimal.

I checked https://pytorch.org/torchtune/main/tutorials/memory_optimizations.html, but it covers general optimizations rather than how this particular result is achieved (e.g. exactly what is offloaded). I also checked https://github.com/pytorch/torchtune/blob/main/recipes/configs/llama3_1/8B_full_single_device.yaml, but it does not seem to be the exact configuration.

[Screenshot of the README line describing full finetuning of an 8B model on a 24GB card]

ebsmothers (Contributor) commented:

Hi @fzyzcjy, thanks for creating the issue! To get to this number we use a few different memory-saving techniques: full bf16 training, a low-precision optimizer (8-bit paged Adam from bitsandbytes), activation checkpointing (optionally also offloading activations to CPU), fusing the optimizer step into the backward pass, and chunked cross-entropy loss.
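
For the "fusing the optimizer step into the backward pass" piece, here is a minimal plain-PyTorch sketch of the idea (not torchtune's actual recipe code; the layer size, learning rate, and plain AdamW below are placeholders, whereas the recipe uses bitsandbytes' paged 8-bit optimizer): each parameter gets its own optimizer, and a post-accumulate-grad hook steps it and frees the gradient as soon as that parameter's gradient is ready, so full-model gradients never have to be held in memory at once.

```python
import torch

# Sketch only: a single linear layer stands in for the model, and plain AdamW
# stands in for the paged 8-bit AdamW that the recipe actually uses.
model = torch.nn.Linear(1024, 1024, dtype=torch.bfloat16)
optimizers = {p: torch.optim.AdamW([p], lr=2e-5) for p in model.parameters()}

def step_when_grad_ready(param: torch.Tensor) -> None:
    # Runs during backward, right after this parameter's gradient is accumulated.
    optimizers[param].step()
    optimizers[param].zero_grad()  # free the gradient immediately

for p in model.parameters():
    p.register_post_accumulate_grad_hook(step_when_grad_ready)

x = torch.randn(2, 1024, dtype=torch.bfloat16)
model(x).sum().backward()  # the optimizer steps happen inside this backward pass
```

The trade-off is the one discussed further down in this thread: because each weight is updated as soon as its gradient is ready, this cannot be combined with gradient accumulation.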

> I would also appreciate knowing whether it comes with drawbacks. For example, when the model is stored in bf16 (instead of fp32), I have observed significant performance degradation for 0.5B–1.5B models, so it may not be optimal.

I'm a bit surprised to hear about the performance degradation with bf16; in my experience, full bf16 training (with certain numerically sensitive operations, like RoPE or cross-entropy, kept in fp32) hasn't shown any degradation compared to fp32. Admittedly, I mostly run 7–8B models, so maybe it is more pronounced on smaller ones. Other than this, low-precision optimizers can potentially reduce model quality, but they're pretty broadly used and I haven't observed much degradation using them myself. Strictly speaking, compile will fuse kernels, so numerical equivalence is not guaranteed there either, but I've never seen it affect my loss curves.
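
To make the "cross-entropy in fp32" and chunked cross-entropy points concrete, here is a rough sketch (an illustration, not the exact torchtune implementation; the token count, vocab size, and chunk count are arbitrary): the bf16 logits are split into chunks, and each chunk is upcast to fp32 only when its loss term is computed, so a full-sequence fp32 logits tensor is never materialized in one go.

```python
import torch
import torch.nn.functional as F

def chunked_ce(logits_bf16: torch.Tensor, labels: torch.Tensor, num_chunks: int = 8) -> torch.Tensor:
    """Cross-entropy over (num_tokens, vocab) bf16 logits, upcast to fp32 one chunk at a time."""
    total = torch.zeros((), dtype=torch.float32)
    for lg, lb in zip(logits_bf16.chunk(num_chunks), labels.chunk(num_chunks)):
        total = total + F.cross_entropy(lg.float(), lb, reduction="sum")
    return total / labels.numel()

logits = torch.randn(4096, 32000, dtype=torch.bfloat16, requires_grad=True)
labels = torch.randint(0, 32000, (4096,))
chunked_ce(logits, labels).backward()
```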

> I checked pytorch.org/torchtune/main/tutorials/memory_optimizations.html, but it covers general optimizations rather than how this particular result is achieved (e.g. exactly what is offloaded). I also checked main/recipes/configs/llama3_1/8B_full_single_device.yaml, but it does not seem to be the exact configuration.

I think the config you linked should be pretty close; you then just need to change the batch size, sequence length, etc. to match what's described there. So something like

```bash
tune run full_finetune_single_device --config llama3/8B_full_single_device compile=True batch_size=2 \
  dataset.packed=True tokenizer.max_seq_len=2048 enable_activation_offloading=True
```

should get pretty close to a repro here.


fzyzcjy commented Dec 19, 2024

@ebsmothers Hi thank you for the detailed reply!


fzyzcjy commented Dec 19, 2024

> fusing the optimizer step into the backward pass

With this, it seems we end up with a very small batch size, while the literature I have seen (in the AI-for-mathematics field) usually uses a batch size of 32 sequences or more. I wonder whether we will run into trouble?


felipemello1 commented Dec 21, 2024

@fzyzcjy, you can set batch_size=1 and gradient_accumulation_steps=32 if you want to (as long as optimizer_in_backward=False). This will accumulate gradients over 32 batches before updating the weights. Also, since you are using dataset.packed=True, you will have multiple samples packed into the same batch, so what you should focus on is the number of tokens or samples rather than the batch size. Just remember to increase your LR as you increase the number of tokens, so you can converge faster. Usually you would want to do a sweep and test something like 1e-4, 1e-5, 1e-6 to see what works best for you.
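
For intuition, here is a minimal plain-PyTorch sketch of gradient accumulation (not the torchtune recipe itself; the model, optimizer, and data below are placeholders): each micro-batch loss is scaled by 1/N and gradients accumulate across N backward passes before a single optimizer step, which is also why it cannot be combined with fusing the optimizer step into backward.

```python
import torch

model = torch.nn.Linear(128, 128)                          # placeholder model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
micro_batches = [torch.randn(1, 128) for _ in range(32)]   # batch_size=1 each

accum_steps = 32  # effective batch = 32 micro-batches
optimizer.zero_grad()
for i, batch in enumerate(micro_batches):
    loss = model(batch).pow(2).mean()   # placeholder loss
    (loss / accum_steps).backward()     # gradients accumulate in .grad
    if (i + 1) % accum_steps == 0:
        optimizer.step()                # one weight update per 32 micro-batches
        optimizer.zero_grad()
```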

Regarding the optimizations, you can get an idea of their impact here: https://github.com/pytorch/torchtune?tab=readme-ov-file#optimization-flags

felipemello1 added the discussion and triaged labels on Dec 21, 2024

fzyzcjy commented Dec 21, 2024

Thank you!

fzyzcjy closed this as completed on Dec 21, 2024