
LongLoRA fine-tuning support #1237

@belerico

Description


LongLoRA is "an efficient fine-tuning approach that extends the context sizes of pre-trained large language models". The authors propose to fine-tune a model with sparse local attention while keeping dense attention for inference. The Shifted-Sparse Attention (S^2-Attn) is depicted in the following figure (from the paper):

[Figure: illustration of Shifted-Sparse Attention (S^2-Attn) from the LongLoRA paper]

Moreover, the required modification is relatively simple (pseudocode from the paper):

# B: batch size;
# N: sequence length or number of tokens;
# G: group size;
# H: number of attention heads;
# D: dimension of each attention head
# qkv in shape (B, N, 3, H, D): projected queries, keys, and values

# Key line 1: split qkv on H into 2 chunks, and shift the second chunk by G/2 on N
qkv = cat((qkv.chunk(2, 3)[0], qkv.chunk(2, 3)[1].roll(-G // 2, 1)), 3).view(B * N // G, G, 3, H, D)

# standard self-attention function, applied within each group
out = self_attn(qkv)

# out reshaped back to (B, N, H, D)
# Key line 2: split out on H into 2 chunks, and roll the second chunk back by G/2 on N
out = cat((out.chunk(2, 2)[0], out.chunk(2, 2)[1].roll(G // 2, 1)), 2)

This can be enabled during the fine-tuning phase only, while the standard dense attention is used during inference.
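For concreteness, here is a self-contained sketch of how I would flesh that pseudocode out in plain PyTorch. The function name, the use of F.scaled_dot_product_attention, and the causal flag are my assumptions (not lit-gpt code), and it glosses over the extra masking that the shifted group spanning the sequence boundary would need:

import torch
import torch.nn.functional as F

def s2_attn(qkv: torch.Tensor, group_size: int) -> torch.Tensor:
    # qkv: (B, N, 3, H, D); N must be a multiple of group_size, H must be even
    B, N, _, H, D = qkv.shape
    G = group_size
    assert N % G == 0 and H % 2 == 0

    # Key line 1: shift the second half of the heads by -G/2 tokens, then group the tokens
    half1, half2 = qkv.chunk(2, dim=3)
    qkv = torch.cat((half1, half2.roll(-G // 2, dims=1)), dim=3)
    qkv = qkv.view(B * N // G, G, 3, H, D)

    # standard (causal) attention within each group
    q, k, v = qkv.permute(2, 0, 3, 1, 4)  # each: (B*N//G, H, G, D)
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

    # reshape back to (B, N, H, D), then roll the shifted heads back (Key line 2)
    out = out.transpose(1, 2).reshape(B, N, H, D)
    half1, half2 = out.chunk(2, dim=2)
    return torch.cat((half1, half2.roll(G // 2, dims=1)), dim=2)

The half of the heads that attends over groups shifted by G/2 is what lets information flow between neighbouring groups despite the attention being local.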

Another thing that should be modified is the padded sequence length, which should be a multiple of the group size.
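For example, something along these lines (the pad token id and the right-padding convention are just placeholders):

import torch
import torch.nn.functional as F

def pad_to_group_multiple(ids: torch.Tensor, group_size: int, pad_id: int = 0) -> torch.Tensor:
    # ids: (B, N) token ids; right-pad N up to the next multiple of group_size
    remainder = ids.size(1) % group_size
    if remainder == 0:
        return ids
    return F.pad(ids, (0, group_size - remainder), value=pad_id)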

If you think this can be added to lit-gpt, I'm willing to contribute a PR (I already have something working that I plan to test).

Edit:

I forgot to mention that they also use Position Interpolation to rescale the position indices. If I'm not mistaken, this can be achieved by simply changing the rope_condense_ratio to account for the increased context size.
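As a rough sketch of what I mean (generic RoPE code, not lit-gpt's actual build_rope_cache), dividing the position indices by the condense ratio keeps the extended context within the position range the model was pre-trained on, e.g. a ratio of 4 when going from a 2k to an 8k context:

import torch

def build_rope_cache(seq_len: int, head_dim: int, base: float = 10000.0,
                     condense_ratio: int = 1) -> tuple:
    # standard RoPE inverse frequencies
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    # Position Interpolation: divide the position indices by the condense ratio so the
    # extended context maps back onto the pre-training position range
    positions = torch.arange(seq_len).float() / condense_ratio
    angles = torch.outer(positions, inv_freq)
    return torch.cos(angles), torch.sin(angles)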
