Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Llama] Make torchao's Llama trainable #728

Merged
merged 4 commits into from
Aug 22, 2024

Conversation

gau-nernst
Copy link
Collaborator

Fixes #674

To make Llama trainable, I changed the following:

  • Do not initialize KV cache if training=True is passed to setup_caches()
  • When input_pos is not passed to the model, handle it accordingly and use F.sdpa(is_causal=True)

Other minor changes:

  • Ignore .safetensors weights in scripts/download.py to save bandwidth/speed up download time.
  • Use torchao's Llama as an example for benchmarks/quantized_training

I manually ran

python torchao/_models/llama/generate.py --checkpoint_path checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --prompt "Hello, my name is" --compile --precision float16

to confirm that the generated outputs are identical.

Copy link

pytorch-bot bot commented Aug 22, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/728

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 806721c with merge base 99644e9 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Aug 22, 2024
@msaroufim msaroufim merged commit 8002099 into pytorch:main Aug 22, 2024
16 checks passed
@gau-nernst gau-nernst deleted the llama_train branch August 22, 2024 18:27
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[Llama] Make Llama in torchao trainable
3 participants