How is "Full finetune 8B on 24GB card" achieved? #2168
Hi @fzyzcjy, thanks for creating the issue! To get to this number we use a few different techniques; the main memory-saving ones are full bf16 training, a low-precision optimizer (8-bit paged Adam from bitsandbytes), activation checkpointing (optionally also offloading activations to CPU), fusing the optimizer step into the backward pass, and chunked cross-entropy loss.
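For context, each of these maps to a knob in the single-device full-finetune configs. A minimal sketch of how they fit together (the key names below are based on recent torchtune configs and may differ slightly between versions) might look like:

```yaml
# Sketch of the memory-relevant fields, loosely modeled on
# recipes/configs/llama3_1/8B_full_single_device.yaml (key names may vary by version)
dtype: bf16                            # full bf16 training

optimizer:
  _component_: bitsandbytes.optim.PagedAdamW8bit  # low-precision, paged 8-bit optimizer
optimizer_in_bwd: True                 # fuse optimizer step into backward, freeing gradient buffers

loss:
  _component_: torchtune.modules.loss.CEWithChunkedOutputLoss  # chunked cross-entropy

enable_activation_checkpointing: True  # recompute activations during backward
enable_activation_offloading: True     # optionally move checkpointed activations to CPU
```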
I'm a bit surprised to hear about the performance degradation with bf16; in my experience full bf16 training (with certain numerically sensitive operations, like RoPE or cross-entropy, kept in fp32) hasn't shown any degradation compared to fp32. But admittedly I run mostly 7-8B models; maybe it is more pronounced on smaller ones. Other than this, low-precision optimizers can potentially reduce model quality, but similarly they're pretty broadly used and I haven't observed much degradation using them myself. Strictly speaking, compile will fuse kernels, so numerical equivalence is not guaranteed there either, but I've never seen it mess with my loss curves.
I think the config you linked should be pretty close; you'd then just need to change the batch size, sequence length, etc. to match what's described there, which should get pretty close to a repro here.
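The exact snippet from the original comment isn't preserved above; as a purely hypothetical illustration (the override values here are assumptions, not the benchmark's actual settings), the idea would be something like:

```yaml
# Hypothetical overrides on top of llama3_1/8B_full_single_device.yaml
# (illustrative values only, not the exact benchmark configuration)
batch_size: 2
tokenizer:
  max_seq_len: 2048        # cap sequence length to bound activation memory
dataset:
  packed: True             # pack multiple samples into each sequence
```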
@ebsmothers Hi, thank you for the detailed reply!
With this setup it seems we will have a very small batch size. The literature I see (in the AI-for-mathematics field) usually uses a batch size of 32 sequences or more, so I wonder whether we will run into trouble?
@fzyzcjy, you can set batch_size=1 and gradient_accumulation_steps=32 if you want to (as long as optimizer_in_bwd=False). This will accumulate gradients over 32 batches before updating the weights. Also, since you are using dataset.packed=True, you will have multiple samples in the same sequence, so what you should focus on is the number of tokens or samples rather than the batch size. Just remember to increase your LR as you increase the number of tokens per step, so you converge faster. Usually you would want to do a sweep and test something like 1e-4, 1e-5, and 1e-6 to see what works best for you. Regarding the optimizations, you can get an idea of their impact here: https://github.com/pytorch/torchtune?tab=readme-ov-file#optimization-flags
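As a concrete sketch of that suggestion (the numbers here are illustrative assumptions, and key names may vary across torchtune versions):

```yaml
# Effective batch = batch_size * gradient_accumulation_steps = 1 * 32 = 32 sequences per update
batch_size: 1
gradient_accumulation_steps: 32
optimizer_in_bwd: False      # gradient accumulation requires the fused optimizer step to be off
optimizer:
  _component_: bitsandbytes.optim.PagedAdamW8bit
  lr: 1e-5                   # sweep e.g. 1e-4 / 1e-5 / 1e-6 and keep the best
dataset:
  packed: True               # packed samples: think in tokens per step, not just batch size
```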
Thank you!
Hi, thanks for the library! I am interested in the following line in the README, since finetuning an 8B model usually needs far more memory than a single 24GB card. I would therefore appreciate knowing the exact techniques used to achieve this, or e.g. a configuration for reproducing it.
Also, I would like to know whether this comes with drawbacks. For example, if the model is stored in bf16 (instead of fp32), I have observed a lot of performance degradation for 0.5B ~ 1.5B models, so it may not be optimal.
I checked the link https://pytorch.org/torchtune/main/tutorials/memory_optimizations.html, but it covers general optimizations rather than how this particular line is achieved (e.g. which exact offloads are done). I also checked https://github.com/pytorch/torchtune/blob/main/recipes/configs/llama3_1/8B_full_single_device.yaml, but it seems it may not be the exact configuration.