Fine-tuning GPT-J-6B in colab: 8-bit weights with low-rank adaptors #14839
I just read the LoRA paper, and your implementation combined with weight quantization is very neat, @deniskamazur. Thank you! A few comments:
The KoboldAI community is really looking forward to seeing these 8-bit models implemented: many users of our software run it on their own home computers, so this allows more people to run 6B at good speeds. If this is not possible, I hope it will be easy to detect that a model is the 8-bit variant, so we can avoid calling half() on the model.
Thank you so much for creating this, Denis!
See the discussion in this issue for more information, but in short, the StableEmbedding layer is only required if the model was pretrained with the StableEmbedding layer. In the case of regular finetuning with 8-bit Adam, it is better to use 32-bit optimizers for the embedding layer. It is currently unclear whether this is required for LoRA, since the frozen 8-bit weights will provide some stability. Just to be sure, it is probably better to optimize the LoRA embedding layer in 32-bit (no change to the model). You can integrate this in your embedding layer class as shown here. Optionally, you can use StableEmbedding in the embedding adapter (this replaces the embedding branch of the notebook's adapter loop):
import torch.nn as nn
import bitsandbytes as bnb

for module in model.modules():  # model, adapter_dim and FrozenBNBEmbedding come from the notebook
    if isinstance(module, FrozenBNBEmbedding):
        module.adapter = nn.Sequential(
            bnb.nn.StableEmbedding(module.num_embeddings, adapter_dim),  # low-rank embedding
            nn.Linear(adapter_dim, module.embedding_dim, bias=False),     # project back up
        )
This is definitely a good idea. My experience with 8-bit weights is that they work fine as long as they are not optimized over time, so keeping them frozen and optimizing the low-rank matrices should work just fine and produce results similar to the LoRA paper. However, I have never tried the setup of 8-bit weights + 16/32-bit low-rank matrices, so it's better to check this.
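The "frozen base weight + trainable low-rank matrices" setup can be sketched in PyTorch as below. This is a minimal illustration, not the notebook's code: an ordinary `nn.Linear` stands in for the frozen 8-bit weight, and the class name `FrozenLinearWithLowRankAdapter` is hypothetical. Zero-initializing the up-projection is one common choice that makes the wrapped layer start out identical to the frozen one.

```python
import torch
import torch.nn as nn


class FrozenLinearWithLowRankAdapter(nn.Module):
    """Frozen base layer plus a trainable low-rank update: y = W x + B(A x)."""

    def __init__(self, base: nn.Linear, rank: int = 8):
        super().__init__()
        self.base = base
        self.base.requires_grad_(False)  # frozen; stands in for the 8-bit weights
        self.adapter = nn.Sequential(
            nn.Linear(base.in_features, rank, bias=False),   # A: down-projection
            nn.Linear(rank, base.out_features, bias=False),  # B: up-projection
        )
        # Zero-init B so the adapter contributes nothing before training.
        nn.init.zeros_(self.adapter[1].weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x) + self.adapter(x)


layer = FrozenLinearWithLowRankAdapter(nn.Linear(16, 16), rank=4)
trainable = [name for name, p in layer.named_parameters() if p.requires_grad]
print(trainable)  # ['adapter.0.weight', 'adapter.1.weight']
```

Only the two small adapter matrices receive gradients, so the optimizer state stays tiny regardless of how the frozen weight is stored.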
Hey, everyone! Thanks for your interest and comments!
These are orthogonal features, so they can probably be implemented separately. Separating them would surely make the PRs simpler to manage, but it'd be good to keep the combined setup in mind from the get-go.
Since you're overriding PyTorch components, this is already generic enough, so the only model-specific changes are the … So here we need a sort of map/policy per arch that runs the right post_init after the map lookup if the model config says so, e.g. … So this is a hardcoded way (to replace the monkeypatch):
and the more generic way, which can then be easily extended to other archs:
which of course should be refactored into a simple:
and since we are likely to have other similar maps as we try to integrate all the new developments, this can again be abstracted away:
which will handle this and other future feature enabling without making the code noisy. On the other hand, it's possible that my proposal won't be supported by others and explicit code will be used for each class/arch. This is all very incomplete pseudocode, just to show what I'm trying to propose conceptually. Here is another example where a policy map is created for different archs: transformers/src/transformers/deepspeed.py Lines 36 to 41 in 10a382b
This is from a very early deepspeed-inference PR #14426
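The per-arch policy map idea can be sketched roughly as follows. This is a hypothetical illustration only: the registry, the `use_8bit_weights` config flag and the hook names do not exist in transformers, and the hook merely marks the model instead of actually swapping modules for 8-bit ones.

```python
from types import SimpleNamespace
from typing import Callable, Dict

# Hypothetical registry: feature flag -> {model_type -> post-init hook}.
POST_INIT_POLICIES: Dict[str, Dict[str, Callable]] = {}


def register_policy(feature: str, model_type: str) -> Callable:
    """Decorator that registers a post-init hook for one architecture."""
    def decorator(fn: Callable) -> Callable:
        POST_INIT_POLICIES.setdefault(feature, {})[model_type] = fn
        return fn
    return decorator


def apply_post_init_policies(model, config):
    """Run every registered hook whose feature flag is enabled in the config."""
    for feature, policies in POST_INIT_POLICIES.items():
        if not getattr(config, feature, False):
            continue  # feature not requested by this model's config
        hook = policies.get(config.model_type)
        if hook is None:
            raise NotImplementedError(f"{feature} is not supported for model_type={config.model_type!r}")
        hook(model, config)
    return model


@register_policy("use_8bit_weights", "gptj")
def gptj_to_8bit(model, config):
    # Placeholder: a real hook would replace Linear/Embedding modules with frozen 8-bit ones.
    model.converted_to_8bit = True


model = SimpleNamespace(converted_to_8bit=False)
config = SimpleNamespace(model_type="gptj", use_8bit_weights=True)
apply_post_init_policies(model, config)
print(model.converted_to_8bit)  # True
```

Adding another feature or architecture then only means registering another hook, which keeps the model classes themselves free of feature-specific branches.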
I'm sure others will have a lot more to say, but since you have the code written already, probably the best way is to just open a PR and go from there. You can start with the hardcoded version, or you can try something like what I suggested, which would immediately lay a foundation for supporting other architectures. As I said earlier, without hearing from other maintainers I'm not sure what the best first step is. The lowest-risk option is the hardcoded one, I'd say. Alternatively, you can wait until Monday, when many devs should be back and may have a chance to comment.
I like your suggestion with the policy map. I think I'll wait for the other maintainers' opinions before opening the PR. Thanks!
I also wonder whether the policy should be arch-specific or model-specific: what if someone wants 8-bit only for the FFN or only for the embedding? If it's model-specific, then the specific params to convert to 8-bit can be declared in the model config. Or perhaps there could be an arch-specific default that the model-specific config can override? Not sure...
In my experience, training with 8-bit dynamic block-wise quantization degrades performance over time, but it is fine if it is applied only once and used for inference or, as in the case of LoRA, as a sort of "base output" that is adapted. As such, I think LoRA might be required to maintain good performance. That being said, I have never tried finetuning a model, only pretraining, so it might be that finetuning with 8-bit weights works just fine. I think the solution with a map specifying the 8-bit parameters would be very handy and would give the flexibility that is needed. What I would add is which kind of int8 data type is used.
I think it should be model-specific. There are certain tradeoffs and important differences between having certain parts in 8-bit and others in 16-bit for the same model architecture, so it would be very useful to have more flexibility overall to accommodate that.
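To make the "arch default + model-specific override" idea concrete, here is a hypothetical sketch: each architecture gets a default list of parameter-name patterns to store in 8-bit plus an int8 data-type label, and a model config can override either (for example, FFN-only). The patterns assume the transformers GPT-J parameter naming (`transformer.h.N.attn...`, `transformer.h.N.mlp...`); the `dynamic_blockwise` label and all function names are illustrative, not an existing API.

```python
from fnmatch import fnmatch
from typing import Dict, List, Optional

# Hypothetical per-architecture defaults.
ARCH_DEFAULTS: Dict[str, Dict] = {
    "gptj": {
        "int8_patterns": ["transformer.wte.*", "transformer.h.*.attn.*", "transformer.h.*.mlp.*"],
        "int8_dtype": "dynamic_blockwise",
    },
}


def resolve_quantization_map(model_type: str, overrides: Optional[Dict] = None) -> Dict:
    """Start from the architecture default and let a model-specific config override keys."""
    qmap = dict(ARCH_DEFAULTS.get(model_type, {"int8_patterns": [], "int8_dtype": None}))
    qmap.update(overrides or {})
    return qmap


def select_int8_params(param_names: List[str], qmap: Dict) -> List[str]:
    """Return the parameter names that match any of the 8-bit patterns."""
    return [n for n in param_names if any(fnmatch(n, p) for p in qmap["int8_patterns"])]


names = [
    "transformer.wte.weight",
    "transformer.h.0.ln_1.weight",
    "transformer.h.0.attn.q_proj.weight",
    "transformer.h.0.mlp.fc_in.weight",
    "lm_head.weight",
]
ffn_only = resolve_quantization_map("gptj", {"int8_patterns": ["transformer.h.*.mlp.*"]})
print(select_int8_params(names, ffn_only))  # ['transformer.h.0.mlp.fc_in.weight']
```

Keeping the data-type label next to the patterns leaves room for the additional int8 variants discussed below without changing the selection logic.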
Did you mean to say something different here, Tim? Unless I misunderstood, int8 is already a single data type. Perhaps you meant flexibility in how many quantization bits are used for different components, so it's not always 8 but can be 4, 16, etc.? Same as
Currently, bnb quantization uses dynamic block-wise quantization by default, so "int8" here refers to a data type defined by the int8 data + an int-to-float map + normalization constants. This data type is storage-optimized. Soon, I will also add another data type to bnb that will be compute-optimized. It is still represented by int8 data + an int-to-float map + normalization constants, but these will be different from, and incompatible with, the storage-optimized variant. At this point, it is already clear to me that the storage data type can be improved quite easily, so it might also be helpful to support that, to make sure future variants can be added easily. On the other hand, it might be better to define it separately, i.e. the config specifies int8 + a quantization method that is defined somewhere else.
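The storage format described above (int8 codes + an int-to-float map + per-block normalization constants) can be illustrated with a toy, pure-Python version. This is a sketch under simplifying assumptions, not bitsandbytes' implementation: a linear 256-entry map stands in for bnb's dynamic (non-linear) map, each block's absolute maximum is used as its normalization constant, and the function names are hypothetical.

```python
from typing import List, Tuple


def linear_codebook(n: int = 256) -> List[float]:
    """Stand-in int-to-float map: n evenly spaced values in [-1, 1] (bnb uses a non-linear map)."""
    return [-1.0 + 2.0 * i / (n - 1) for i in range(n)]


def toy_quantize_blockwise(
    x: List[float], codebook: List[float], block_size: int = 64
) -> Tuple[List[int], List[float]]:
    codes: List[int] = []
    absmax: List[float] = []
    for start in range(0, len(x), block_size):
        block = x[start:start + block_size]
        scale = max(abs(v) for v in block) or 1.0  # normalization constant; avoids /0 for all-zero blocks
        absmax.append(scale)
        for v in block:
            target = v / scale  # now in [-1, 1]
            # Store the index (0..255, fits in one byte) of the nearest map entry.
            codes.append(min(range(len(codebook)), key=lambda i: abs(codebook[i] - target)))
    return codes, absmax


def toy_dequantize_blockwise(
    codes: List[int], absmax: List[float], codebook: List[float], block_size: int = 64
) -> List[float]:
    # Look up each code in the map and undo the per-block normalization.
    return [codebook[c] * absmax[i // block_size] for i, c in enumerate(codes)]


codebook = linear_codebook()
x = [0.1, -0.5, 2.0, 0.0, 3.0, -3.0, 1.5, 0.25]
codes, absmax = toy_quantize_blockwise(x, codebook, block_size=4)
restored = toy_dequantize_blockwise(codes, absmax, codebook, block_size=4)
```

Swapping in a different map or a different normalization scheme changes the data type while the byte-per-value storage stays the same, which is why two int8 variants can be mutually incompatible.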
Sounds good, Tim. So I trust you will come up with the different names then. We just need to think about how to make it easily extensible in the future to support other types. My thinking is that BNB may not be the only library providing quantization support, so the more generic it is, the better. We can start with one model and flag it as experimental until we sort out the config.
Hi, everyone! Thank you for your suggestions. I'm currently busy with my uni exams, but I'll be back with a PR in a couple of weeks.
I have a question about the inference speed of the quantized model (hivemind/gpt-j-6B-8bit) compared with the original model (EleutherAI/gpt-j-6B).
Hi! The inference speed is indeed slower because you de-quantize the weight matrices for every token. The same is true for training, but the fine-tuning speed is not significantly different from the original model because training is parallel over … You can combine the two setups (vanilla and 8-bit) to better fit your hardware.
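A toy illustration of why generation pays more for de-quantization than training does, assuming a frozen layer that stores int8 codes and de-quantizes its whole weight matrix inside every forward call. It uses a simple symmetric per-row scheme rather than bnb's block-wise one, and the class name is hypothetical. Autoregressive generation calls the layer once per new token, whereas a training step can push all positions of a batch through a single call, so the same cost is spread over many more tokens.

```python
from typing import List


class FrozenInt8Linear:
    """Toy frozen linear layer: int8 codes + one absmax per row, de-quantized on every call."""

    def __init__(self, weight: List[List[float]]):
        self.absmax = [max(abs(w) for w in row) or 1.0 for row in weight]
        # Symmetric linear quantization of each row to the int8 range [-127, 127].
        self.codes = [[round(w / s * 127) for w in row] for row, s in zip(weight, self.absmax)]
        self.dequant_calls = 0

    def dequantize(self) -> List[List[float]]:
        self.dequant_calls += 1
        return [[c / 127 * s for c in row] for row, s in zip(self.codes, self.absmax)]

    def __call__(self, batch: List[List[float]]) -> List[List[float]]:
        w = self.dequantize()  # one full de-quantization per call, however many inputs are in the batch
        return [[sum(wi * xi for wi, xi in zip(row, x)) for row in w] for x in batch]


layer = FrozenInt8Linear([[0.5, -1.0], [2.0, 0.25]])

# Generation: one token at a time, so every new token triggers a de-quantization.
for _ in range(8):
    layer([[1.0, 1.0]])
print(layer.dequant_calls)  # 8

# Training: all 8 positions go through a single call, so the cost is paid once.
layer.dequant_calls = 0
layer([[1.0, 1.0]] * 8)
print(layer.dequant_calls)  # 1
```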
I really can't hold back from saying: this is awesome! Thank you 🙏 Good luck with your studies; I hope that once you're finished I can help you somehow with the next steps.
Hey, everyone! I've implemented the «hardcoded» version of this issue. You can verify it's functional over here. Should I add any tests before opening a PR? I'd also be glad to implement LoRA and a generalized version of this issue in future PRs.
Awesome news, @deniskamazur! I won't have time at this moment to support this process very closely but I trust there will be other maintainers who will have a closer look and provide feedback once you open a PR.
Definitely, and you can use this tiny model for functionality tests: https://huggingface.co/hf-internal-testing/tiny-random-gptj. As we have a massive test suite, it should be relatively easy to build upon or mimic some of the existing tests. And if you get stuck, please don't hesitate to ask in the PR.
Great, thanks! I'll open a PR as soon as I've written the test.
Hey! I've noticed this PR, which seems to generalize what we are doing with gpt-j-8bit. What should I do with this issue?
Hi Denis, it has been a long time... Perhaps there has been a misunderstanding, as we have been waiting for you to complete the PR, so nothing has happened here until now. Let's tag @younesbelkada, whose PR you linked to. Younes, not to load more work on you, but a quick question: does your PR supersede Denis' work, or is there some collaboration that can happen here?
Hi @deniskamazur @stas00 |
I suppose the advantage of loading in int8 is that with fp16 you need 2x the memory upfront, but since we now have sharded checkpoints, this can be overcome by splitting into smaller shards if someone is really tight on memory, so that the embedding ends up being the largest single param. Otherwise, I'll let you discuss the pros and cons of each approach, as I'm still busy with BigScience, but I would love to study this more closely and support you once the marathon is over. Maybe let's also cc @justheuristic on this discussion. Between the four of you, this domain is in good hands.
Hey, thanks for your notebook! I am trying to run it; however, I get the following error when installing bitsandbytes-cuda111 with your specified version 0.26.0: … Please let me know which version I should use instead. Thanks!
Change
As per the
@petertjmills Same here. Using int8, the original model fits on an 8 GB NVIDIA GeForce GTX 1080, but crashes after the first generation. The Hivemind model uses float16 or float32 for computation, so it's even less likely to succeed. Probably at least 9-10 GB of VRAM is needed.
I am getting the following error when attempting to fine-tune: Traceback (most recent call last): … Any idea how to solve this? Edit: I was able to get the fine-tuning going by modifying the following part:
To:
After training the model with this notebook, how can it be saved and loaded back? If I try
I can save the model, but then if I try to load it back in another script with
I get the following warning:
And the loaded model only produces garbage output. Alternatively, if I try to load it with
I get an error:
To the best of my knowledge, you will need to manually extract and save the model's state dict, containing only the modules you have trained, and then load that state dict with model.load_state_dict.
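One way to do this, sketched under assumptions: treat every parameter with `requires_grad=True` as "trained" (e.g. the LoRA adapters), save only those tensors with `torch.save`, and later rebuild the model the same way (8-bit base + adapters) before loading them with `load_state_dict(..., strict=False)`. The helper names and the toy model are hypothetical; with the notebook you would rebuild the model using its own conversion and adapter code instead of `ToyAdapted`.

```python
import torch
import torch.nn as nn


def trainable_state_dict(model: nn.Module) -> dict:
    """Collect only the tensors that were trained (requires_grad=True), e.g. LoRA adapters."""
    trainable = {name for name, p in model.named_parameters() if p.requires_grad}
    return {k: v.detach().cpu().clone() for k, v in model.state_dict().items() if k in trainable}


def load_trainable_state_dict(model: nn.Module, state: dict) -> nn.Module:
    """Load adapter weights into a model that was rebuilt with the same frozen base + adapters."""
    result = model.load_state_dict(state, strict=False)
    # Frozen weights come from the base checkpoint, so only "missing" keys are expected here.
    if result.unexpected_keys:
        raise KeyError(f"unexpected keys: {result.unexpected_keys}")
    return model


class ToyAdapted(nn.Module):
    """Toy stand-in for the fine-tuned model: a frozen layer plus a trainable low-rank adapter."""

    def __init__(self):
        super().__init__()
        self.base = nn.Linear(4, 4).requires_grad_(False)
        self.adapter = nn.Sequential(nn.Linear(4, 2, bias=False), nn.Linear(2, 4, bias=False))


trained = ToyAdapted()
torch.save(trainable_state_dict(trained), "adapters.pt")

restored = ToyAdapted()  # in practice: reload the 8-bit base model and re-add the adapters
load_trainable_state_dict(restored, torch.load("adapters.pt"))
```

Saving only the trained tensors keeps the checkpoint small and sidesteps serializing the custom 8-bit modules, which is one plausible reason a plain save/reload of the whole model can misbehave.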
Hi, thanks for your very nice work! I tried to almost blindly copy-paste your notebook into a blank Colab notebook (simple standard free account). I only encountered one error, almost at the beginning: CUDA SETUP: Required library version not found: libsbitsandbytes_cpu.so. Maybe you need to compile it from source? I ignored it (of course I felt that something was not right) and got no other errors up to the point where it calls gpt.generate to produce new text, i.e. before fine-tuning. The gpt.generate call has been running for 25 minutes. I suspect this slowness is not normal, but rather an effect of not using the GPU. Is that correct? Any suggestions on how to solve it? Thanks
Hi @andreo73, you need to install the CUDA version of bitsandbytes.
Has anyone already tried fine-tuning this with the Alpaca approach?
Runtime error when batching
I'm having issues getting the proof-of-concept notebook to work with a batch size > 1. The original notebook just iterates over the sample dataset row by row (one example at a time), which also works fine for my dataset. However, when I feed batches of more than one example to the model (in … The same happens when I use the … My batches are of the form
🌟 New model addition
Model description
This is a version of EleutherAI's GPT-J with 6 billion parameters that is modified so you can generate with and fine-tune the model in Colab or on an equivalent desktop GPU (e.g. a single 1080 Ti).
The original GPT-J takes 22+ GB of memory for float32 parameters. Even if you cast everything to 16-bit, it will still not fit onto most single-GPU setups short of an A6000 or A100. You can run inference on TPUs or CPUs, but fine-tuning is much more expensive.
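A back-of-the-envelope check of these figures, counting weights only (no activations, gradients, optimizer state, or the small overhead of 8-bit quantization statistics) and assuming the roughly 6.05 billion parameters listed on the EleutherAI/gpt-j-6B model card:

```python
# Approximate weight memory for GPT-J-6B at different precisions.
# Assumption: parameter count as listed on the EleutherAI/gpt-j-6B model card.
N_PARAMS = 6_053_381_344
BYTES_PER_PARAM = {"float32": 4, "float16": 2, "int8": 1}


def weight_memory_gib(n_params: int, dtype: str) -> float:
    """Bytes needed for the raw weights, expressed in GiB (2**30 bytes)."""
    return n_params * BYTES_PER_PARAM[dtype] / 2**30


for dtype in BYTES_PER_PARAM:
    print(f"{dtype:>8}: {weight_memory_gib(N_PARAMS, dtype):5.1f} GiB")
```

This gives roughly 22.6, 11.3 and 5.6 GiB respectively, which is why storing the frozen weights in 8-bit leaves room on a single consumer GPU for activations and the small trainable adapters.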
Implementation
Proof-of-concept notebook is available here:
Model card has more detailed explanations and auxiliary notebooks (e.g. model conversion and perplexity check).
The current implementation is somewhat hacky, but it can be integrated easily with modeling_gptj.py if you like the idea.
Open source status