Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for GPTNeoX models #32

Merged
merged 17 commits into from
Oct 3, 2023
Merged

Conversation

naubull2
Copy link
Contributor

@naubull2 naubull2 commented Oct 3, 2023

Adds Long-LoRA support for GPTNeoX models.

Tested on a colab A100 40GB x 1 instance, with the scripts

  • fine-tune.py
  • supervised-fine-tune.py

Using a sample GPTNeoX model

  • EleutherAI/pythia-1.4b-deduped

As there was no specific guide on how to contribute, I've tried to make as little modification as possible to the original structure.

Added GPTNeoX support by adding a module gptneox_attn_replace just as the original llama_attn_replace.

How to apply

Application is showcased in the tested scripts fine-tune.py, supervised-fine-tune.py

Add

  • model_type argument to switch back and forth between the llama and gpt-neox configuration.
  • import
    from gptneox_attn_replace import replace_gpt_neox_attn
  • Appropriate changes needed for low rank training
     if training_args.low_rank_training:
          if model_args.model_type == "gpt-neox":
              # added `dense` to match with llama as the vanilla peft config would only target 'query_key_value'
              targets = ["query_key_value", "dense"]
          else:
              targets=["q_proj", "k_proj", "v_proj", "o_proj"],
    
          config = LoraConfig(
              r=8,
              lora_alpha=16,
              target_modules=targets,
              lora_dropout=0,
              bias="none",
              task_type="CAUSAL_LM",
          )
          model = get_peft_model(model, config)

Notes on flash-attention + GPTNeoX

  • As the huggingface implementation won't support flash attention off the shelf, I modified some parts from modeling_gpt_neox.py, for the use_flash_attn=True case.
    • transformers == 4.33.3 as of writing.
    • Mainly the part where cached cos/sin rotary embedding is in fp32 where flash-attn requires tensors to be in fp16/bf16 only.
  • Some how the original flash-attention2 interface flash_attn_varlen_func would cause a runtime error of "in-place operation" flash-attention code
    • So I've opted for flash_attn_varlen_qkvpacked_func which worked fine.
      • In changing the dimensions to fit in, I referenced codes by Philipp Schmid🤗 ref

@yukang2017 yukang2017 merged commit 04c8db1 into dvlab-research:main Oct 3, 2023
@yukang2017
Copy link
Member

Hi,

Many thanks for your contribution. These commits are really helpful for this project. I have merged them in to the main branch!

Regards,
Yukang Chen

gianlucamacri pushed a commit to gianlucamacri/LongLoRA that referenced this pull request Oct 31, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants