Description
LongLoRA is "an efficient fine-tuning approach that extends the context sizes of pre-trained large language models". The authors propose to fine-tune a model with sparse local attention while keeping standard dense attention at inference time. The Shifted-Sparse Attention (S²-Attn) scheme is depicted in the following figure (from the paper):

[Figure: Shifted-Sparse Attention (S²-Attn) overview, from the LongLoRA paper]
Moreover, the required code change is relatively small (pseudocode adapted from the paper):
```python
import torch

# B: batch size
# N: sequence length (number of tokens)
# G: group size
# H: number of attention heads
# D: dimension of each attention head

# qkv in shape (B, N, 3, H, D): projected queries, keys, and values
# Key step 1: split qkv on H into 2 chunks, shift the second chunk by G/2 along N, then group
qkv = torch.cat((qkv.chunk(2, 3)[0], qkv.chunk(2, 3)[1].roll(-(G // 2), 1)), 3).view(B * N // G, G, 3, H, D)
# standard self-attention within each group
out = self_attn(qkv)
# out in shape (B, N, H, D)
# Key step 2: split out on H into 2 chunks, and roll the second chunk back by G/2 along N
out = torch.cat((out.chunk(2, 2)[0], out.chunk(2, 2)[1].roll(G // 2, 1)), 2)
```
This sparse attention can be enabled only during the fine-tuning phase, while the standard dense attention is kept for inference.
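To make the shift-group-attend-unshift pattern concrete, here is a self-contained PyTorch sketch. This is not the lit-gpt code: the function name and the use of `F.scaled_dot_product_attention` as the inner dense attention are my own choices, the causal mask is skipped for simplicity, and it assumes an even number of heads and a sequence length that is a multiple of the group size.

```python
import torch
import torch.nn.functional as F

def s2_attn(qkv: torch.Tensor, group_size: int) -> torch.Tensor:
    # qkv: (B, N, 3, H, D); assumes H is even and N is a multiple of group_size.
    B, N, _, H, D = qkv.shape
    G = group_size

    # Shift the second half of the heads by G/2 tokens so that information
    # can flow between neighbouring groups.
    h1, h2 = qkv.chunk(2, dim=3)
    qkv = torch.cat((h1, h2.roll(-(G // 2), dims=1)), dim=3)

    # Group tokens and run dense attention within each group
    # (non-causal here for simplicity; LM training would need a causal mask).
    qkv = qkv.view(B * N // G, G, 3, H, D)
    q, k, v = (t.squeeze(2).transpose(1, 2) for t in qkv.chunk(3, dim=2))
    out = F.scaled_dot_product_attention(q, k, v)  # (B*N/G, H, G, D)

    # Ungroup back to (B, N, H, D).
    out = out.transpose(1, 2).reshape(B, N, H, D)

    # Roll the shifted half of the heads back into place.
    o1, o2 = out.chunk(2, dim=2)
    return torch.cat((o1, o2.roll(G // 2, dims=1)), dim=2)

# Example: 2 sequences of 1024 tokens, 8 heads of size 64, group size 256.
qkv = torch.randn(2, 1024, 3, 8, 64)
print(s2_attn(qkv, group_size=256).shape)  # torch.Size([2, 1024, 8, 64])
```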
Another thing that needs to change is the padded sequence length, which must be a multiple of the group size; a possible helper is sketched below.
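Something along these lines should do (the helper name and the `pad_id` argument are just placeholders, not existing lit-gpt APIs):

```python
import torch
import torch.nn.functional as F

def pad_to_group_size(input_ids: torch.Tensor, group_size: int, pad_id: int) -> torch.Tensor:
    # Right-pad a (B, N) batch of token ids so that N becomes a multiple of group_size.
    remainder = input_ids.size(1) % group_size
    if remainder == 0:
        return input_ids
    return F.pad(input_ids, (0, group_size - remainder), value=pad_id)

# Example: 1000 tokens padded up to 1024 with a group size of 256.
ids = torch.randint(0, 32000, (2, 1000))
print(pad_to_group_size(ids, group_size=256, pad_id=0).shape)  # torch.Size([2, 1024])
```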
If you think this could be added to lit-gpt, I'm willing to contribute a PR (I already have something working that I plan to test).
Edit:
I forgot to mention that they also use Position Interpolation to rescale the position indices. If I'm not mistaken, this can be achieved by simply changing `rope_condense_ratio` to account for the increased context size.
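If I'm reading the config correctly, something like this should be enough; the model name and numbers below are only an example:

```python
from lit_gpt import Config

# Hypothetical example: extending a 4096-token Llama 2 7B to 16384 tokens.
# Position Interpolation rescales the position indices by the extension
# factor, i.e. rope_condense_ratio = 16384 / 4096 = 4.
config = Config.from_name(
    "Llama-2-7b-hf",
    block_size=16384,
    rope_condense_ratio=4,
)
```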