Implementation of Self-Extend, proposed in the paper LLM Maybe LongLM: Self-Extend LLM Context Window Without Tuning.
- [05/31/2024]: 🎉 SelfExtend was highlighted in a Google I/O session on YouTube to demonstrate the long-context ability of Gemma!
- [05/01/2024]: 🎉 SelfExtend has been accepted by ICML 2024! See you in Vienna!
- [04/19/2024]: 💡 We added support for Llama-3 with transformers==4.40. To use it with transformers==4.40, rename Llama_4_40.py to Llama.py to replace the existing patch file. (Although Llama-3's official model hub recommends transformers==4.40, we find that Llama-3 also works well with transformers==4.38.2.)
- [04/06/2024]: We added some hyperparameter search results for SelfExtend; you may check here.
- [03/24/2024]: We added a Triton implementation of flash self-extend. You can now use our Triton-implemented FlashSelfExtend to enjoy Self-Extend!
- [03/20/2024]: We made many updates:
  - We added a FlashAttention implementation of Self-Extend, credits to Qingquan Song! This implementation uses the original flash_attn from Tri Dao.
  - We also tried to implement FlashAttention for Self-Extend using Triton. Currently, it only works for the prefill stage and does not work well for the decoding stage. We are eagerly debugging this. Any suggestions are very welcome!
  - We added support for Qwen1.5. Check the code for more details.
  - We reorganized this repo and refactored several files. All code now runs with transformers==4.38.2 and flash_attn==2.5.6. Legacy code, examples, and README are packed into legacy_patch_before_4_38. We recommend using our Docker image hoytjin/selfextend_docker:v0.1 to avoid environment issues.
  - We updated the API. Now you can simply call SelfExtend.apply(loaded_model, group_size, window_size) to enjoy Self-Extend! Check and run the provided example for more details.
  - We added a new passkey example with a 32k context length and a more challenging 10-digit passkey.
  - Please join our Discord for discussion! 🔥🔥
- [02/22/2024]: We added the implementation for Google's new LLM, Gemma! You are welcome to try it out!
- [01/19/2024]: We added the implementation for Llama with transformers 4.36.2 and the implementation for Microsoft's official phi-2 with transformers 4.37. More good news: the flash attention version will come in days! 💥
- [01/11/2024]: We tested the implementation for phi-2, and it works. You may find some results in this Reddit post and details in this X post.
- [01/08/2024]: Added a third-party implementations section.
- [01/07/2024]: Added the implementation for Mistral.
- [01/05/2024]: Our proposed method is discussed in this Reddit post.
- Gemma-7b has to be loaded in bfloat16, but Gemma-2b still works well with float16.
- If you are using transformers 4.36, the default attention class used by Llama is LlamaSdpaAttention rather than LlamaAttention. Be careful about this and make sure you replace the forward method of the correct class.
- Mistral's sliding window attention (SWA) should be disabled to use Self-Extend. The reason why SWA should not be used can be found in our paper.
Llama.cpp https://github.com/ggerganov/llama.cpp
Llama.cpp has a great implementation and integration of Self-Extend! Give it a try! 😄
This work elicits LLMs' inherent ability to handle long contexts without fine-tuning. The limited length of the sequences seen during training may restrict the application of Large Language Models (LLMs) to long input sequences at inference time. In this work, we argue that existing LLMs themselves have inherent capabilities for handling long contexts. Based on this argument, we suggest extending an LLM's context window by itself to fully utilize this inherent ability. We propose Self-Extend to stimulate LLMs' long-context handling potential. The basic idea is to construct bi-level attention information: the group level and the neighbor level. Both levels are computed by the original model's self-attention, which means the proposed method does not require any training.
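To make the bi-level idea concrete, here is a minimal Python sketch of the relative position a query sees for each key under Self-Extend. It assumes the floor-division mapping described in the paper. Keys inside the neighbor window keep their exact relative distance. Farther keys use positions divided by the group size, and the query's grouped position is shifted so that grouped distances never fall below the neighbor window. The function names are hypothetical and not part of this repo. The repo's patches apply a mapping of this kind inside the attention forward pass rather than pair by pair.

```python
from __future__ import annotations


def self_extend_rel_pos(q_pos: int, k_pos: int, group_size: int, neighbor_window: int) -> int:
    """Relative position used when the query at q_pos attends to the key at k_pos.

    Assumes causal attention (k_pos <= q_pos) and 0-based positions.
    """
    if k_pos > q_pos:
        raise ValueError("causal attention: k_pos must not exceed q_pos")
    distance = q_pos - k_pos
    if distance < neighbor_window:
        # Neighbor level: nearby tokens keep their exact relative distance.
        return distance
    # Group level: floor-divide positions by the group size, and shift the
    # grouped query position so grouped distances do not fall below neighbor_window.
    shift = neighbor_window - neighbor_window // group_size
    return (q_pos // group_size + shift) - k_pos // group_size


def rel_pos_matrix(seq_len: int, group_size: int, neighbor_window: int) -> list[list[int | None]]:
    """Lower-triangular matrix of relative positions (None above the diagonal)."""
    return [
        [self_extend_rel_pos(i, j, group_size, neighbor_window) if j <= i else None
         for j in range(seq_len)]
        for i in range(seq_len)
    ]


if __name__ == "__main__":
    # 16 tokens, group size 4, neighbor window 4: the farthest pair is
    # mapped to relative position 6 instead of 15.
    m = rel_pos_matrix(16, group_size=4, neighbor_window=4)
    print(m[15])  # [6, 6, 6, 6, 5, 5, 5, 5, 4, 4, 4, 4, 3, 2, 1, 0]
```

With group_size=1 the shift is zero and the mapping reduces to ordinary relative positions, which makes a quick sanity check.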
For the current Llama implementation, the Python packages used are:
transformers==4.38.2
flash_attn==2.5.6
We recommend using this Docker image: hoytjin/selfextend_docker:v0.1
We previously provided patches for several models. You may check legacy_patch_before_4_38, which contains the legacy patches (llama, mistral, phi, etc.) and their README.
Clone the repository to your machine and copy your modeling files into the cloned repo directory.
import SelfExtend
# Load your model, e.g., loaded_model = AutoModelForCausalLM.from_pretrained(model_path)
# Arguments: group size, neighbor window size.
SelfExtend.apply(loaded_model, group_size, window_size, enable_flash_attention=False)
# Inference, e.g., loaded_model.generate(...)
enable_flash_attention is False by default. You may set enable_flash_attention=True if the model is loaded with FlashAttention enabled.
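If you are unsure how the model was loaded, a small helper can derive this flag from the model config. This is a hedged sketch. It assumes a transformers version (4.36 or later) that records the chosen attention backend in the private config attribute _attn_implementation, with the value "flash_attention_2" when FlashAttention-2 is enabled. The helper name is hypothetical and not part of this repo.

```python
from types import SimpleNamespace


def loaded_with_flash_attention(config) -> bool:
    """Return True if the config says FlashAttention-2 was selected at load time.

    Assumption: transformers >= 4.36 stores the backend ("eager", "sdpa",
    "flash_attention_2") in the private attribute `_attn_implementation`.
    Configs without that attribute are treated as not using FlashAttention.
    """
    return getattr(config, "_attn_implementation", None) == "flash_attention_2"


if __name__ == "__main__":
    # Stand-in configs so the helper can be exercised without loading a model.
    print(loaded_with_flash_attention(SimpleNamespace(_attn_implementation="flash_attention_2")))  # True
    print(loaded_with_flash_attention(SimpleNamespace(_attn_implementation="sdpa")))  # False
    print(loaded_with_flash_attention(SimpleNamespace()))  # False
```

Usage, under the same assumption: SelfExtend.apply(loaded_model, group_size, window_size, enable_flash_attention=loaded_with_flash_attention(loaded_model.config)).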
We use passkey retrieval as an example to show how to use Self-Extend. You may check example.py:
python example.py
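For readers who want to build their own inputs, the sketch below generates a passkey-retrieval prompt in a common style. A random numeric passkey is hidden at a random depth inside repeated filler text, followed by a question. The filler sentences, prompt wording, and function names are assumptions for illustration. example.py may construct its prompts differently.

```python
from __future__ import annotations

import random

# Hypothetical filler text; any irrelevant, repetitive text serves the same purpose.
FILLER = ("The grass is green. The sky is blue. The sun is yellow. "
          "Here we go. There and back again. ")


def make_passkey_prompt(n_filler: int, digits: int = 10, seed: int = 0) -> tuple[str, str]:
    """Return (prompt, passkey) with the passkey hidden among n_filler filler blocks."""
    rng = random.Random(seed)
    # First digit is 1-9 so the passkey has exactly `digits` significant digits.
    passkey = str(rng.randint(1, 9)) + "".join(str(rng.randint(0, 9)) for _ in range(digits - 1))
    blocks = [FILLER] * n_filler
    # Insert the key line at a random depth (0 = before all filler, n_filler = after it).
    key_line = f"The pass key is {passkey}. Remember it. {passkey} is the pass key. "
    blocks.insert(rng.randint(0, n_filler), key_line)
    prompt = ("There is important info hidden inside a lot of irrelevant text. "
              "Find it and memorize it. "
              + "".join(blocks)
              + "What is the pass key? The pass key is")
    return prompt, passkey


def is_correct(generated: str, passkey: str) -> bool:
    """Score a model continuation: correct if it contains the passkey."""
    return passkey in generated


if __name__ == "__main__":
    prompt, key = make_passkey_prompt(n_filler=200, digits=10, seed=42)
    print(len(key), key in prompt)  # 10 True
```

Increasing n_filler lengthens the prompt, so the same generator can probe whether a chosen group size and neighbor window cover the target length.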
The following thoughts are based on our experience:
- With Llama-2 as the base model, values of 2~64 are reasonable for group_size and 512~1536 are feasible for neighbor_window. However, a larger group_size and a smaller neighbor_window also work well in many cases.
- The general rule for choosing group_size and neighbor_window is to ensure that the input sequence length is within the maximum extended window size (for Llama-2, this is (4096 - neighbor_window) * group_size + neighbor_window).
- We did not choose the group size carefully. For the same sequence, a smaller group size should be better, but we found that this does not strictly hold in some experiments:
  Sometimes, a larger group size can be beneficial. This may be because larger positions are not well trained: a larger group size reuses smaller positions, which received more training, to facilitate extension. However, smaller group sizes tend to have better precision, so there is a trade-off. For more details, refer to the ablation study section.
  For example: if the input length for a QA task is 15,800 and the neighbor window is set to 1,024, the group size can be set to 5, because 5 * (4,096 - 1,024) + 1,024 = 16,384, which is greater than 15,800. However, setting the group size to 6, or even larger, such as 8 or 16, might improve the model's performance. With a group size of 5, Self-Extend uses positions 1,025 to 3,979 to extend the context window; with a group size of 8, it uses positions 1,025 to 2,871. Although a group size of 8 is less precise than a group size of 5, the positions 2,872 to 3,979 used by a group size of 5 are less trained during pretraining, which may hurt the extension.
- Maybe, for a sequence of length L, you can first try the smallest group size G that satisfies G * (4096 - w_n) + w_n ≥ L (for Llama-2, where w_n is the neighbor window), and then test whether a larger group size works better.
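The arithmetic in these rules and in the QA example can be checked with a few lines of Python. This sketch assumes a Llama-2-style 4,096-token pretraining window. The function names are hypothetical helpers, not part of this repo. largest_position_used mirrors the README's own arithmetic: the neighbor window plus the remaining length divided by the group size.

```python
def max_extended_window(pretrain_window: int, group_size: int, neighbor_window: int) -> int:
    """(pretrain_window - neighbor_window) * group_size + neighbor_window."""
    return (pretrain_window - neighbor_window) * group_size + neighbor_window


def smallest_group_size(seq_len: int, pretrain_window: int, neighbor_window: int) -> int:
    """Smallest G with max_extended_window(pretrain_window, G, neighbor_window) >= seq_len."""
    if neighbor_window >= pretrain_window:
        raise ValueError("neighbor_window must be smaller than pretrain_window")
    if seq_len <= pretrain_window:
        return 1  # the sequence already fits; no extension needed
    # Integer ceiling division avoids floating-point rounding.
    return -(-(seq_len - neighbor_window) // (pretrain_window - neighbor_window))


def largest_position_used(seq_len: int, group_size: int, neighbor_window: int) -> int:
    """Largest position used for extension, following the example's arithmetic."""
    return neighbor_window + (seq_len - neighbor_window) // group_size


if __name__ == "__main__":
    L, W, w_n = 15_800, 4_096, 1_024
    g = smallest_group_size(L, W, w_n)
    print(g, max_extended_window(W, g, w_n))  # 5 16384
    for G in (5, 8):
        print(G, largest_position_used(L, G, w_n))  # prints "5 3979", then "8 2871"
```

The smallest valid G is only a starting point. Comparing largest_position_used across candidate group sizes shows how much of the less-trained upper position range each choice relies on.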
Denoting the pretraining context window as
[TLDR]
SelfExtend is not overly sensitive to hyperparameter selection. One could use a representative task to find proper hyperparameters, or directly follow our empirical inequality:
If you find our method useful, please kindly cite our paper.
@misc{jin2024llm,
title={LLM Maybe LongLM: Self-Extend LLM Context Window Without Tuning},
author={Hongye Jin and Xiaotian Han and Jingfeng Yang and Zhimeng Jiang and Zirui Liu and Chia-Yuan Chang and Huiyuan Chen and Xia Hu},
year={2024},
eprint={2401.01325},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
We welcome contributions from the research community to improve the efficiency of SelfExtend. If you have any ideas or would like to report a bug, please open an issue or submit a pull request.
The code is released under the MIT License.