
[Question] triton jit of LigerRopeFunction runs every step when variable length sequences are used #146

Closed
tyler-romero opened this issue Aug 28, 2024 · 10 comments · Fixed by #149
Labels
bug (Something isn't working), p0

Comments

@tyler-romero
Collaborator

tyler-romero commented Aug 28, 2024

🐛 Describe the bug

When training a model using LigerKernels, I don't get the speedup I expect. Profiling with and without LigerKernels shows that the bottleneck appears to be the Triton JIT compilation of LigerRopeFunction, which happens at every step instead of only once. Other LigerKernels do not take as long to JIT (or they simply aren't being re-JITed after the first step).

Side-by-side comparison of RoPE with and without Liger:
[profiler trace screenshot]

LigerRopeFunction in the context of an entire forward pass:
[profiler trace screenshot]

I am training with dynamic padding (so sequence lengths differ at every forward pass). This currently seems to be the only LigerKernel that depends on seq_len in a non-batch dimension (LigerCrossEntropyFunction also depends on seq_len, but it folds it into the batch dimension).

Does this intuition about why LigerRopeFunction is JIT-compiled every forward pass, while LigerCrossEntropyFunction is not, make sense?
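
To make the intuition concrete, here is a minimal standalone Triton sketch (not the actual Liger RoPE kernel; the kernel name and shapes are made up for illustration). Because `tl.constexpr` arguments are part of Triton's JIT cache key, every distinct `seq_len` value triggers a fresh compilation:

```python
import torch
import triton
import triton.language as tl


@triton.jit
def _copy_kernel(x_ptr, y_ptr, seq_len: tl.constexpr, BLOCK_SIZE: tl.constexpr):
    # seq_len is a compile-time constant here, so each distinct value
    # produces a separate compiled variant of this kernel.
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < seq_len
    x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
    tl.store(y_ptr + offsets, x, mask=mask)


# With dynamic padding, seq_len changes almost every step, so a kernel
# written this way gets re-JITed almost every step.
for seq_len in (128, 131, 256):
    x = torch.randn(seq_len, device="cuda")
    y = torch.empty_like(x)
    grid = (triton.cdiv(seq_len, 64),)
    _copy_kernel[grid](x, y, seq_len=seq_len, BLOCK_SIZE=64)
```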

Reproduce

No response

Versions

> python -m liger_kernel.env_report
Environment Report:
-------------------
Operating System: Linux-6.5.0-44-generic-x86_64-with-glibc2.35
Python version: 3.10.13
PyTorch version: 2.3.0
CUDA version: 12.1
Triton version: 2.3.0
Transformers version: 4.42.3
@ByronHsu
Collaborator

triton-lang/triton#3166 seems related to this. We probably need the Triton folks' help @ptillet @Jokeren. Can you also cross-post in a Triton issue?

@ByronHsu
Collaborator

cc @yundai424 @lancerts if you have insights

@yundai424
Collaborator

My hunch is that it has something to do with the fact that we treat the sequence length as a constexpr in the RoPE kernel 🤔 If it has to be known at compile time, it makes sense that each distinct sequence length produces a new function signature / cache entry.

yundai424 added the bug and p0 labels Aug 28, 2024
@ByronHsu
Collaborator

ByronHsu commented Aug 28, 2024

See the discussion thread at https://discord.com/channels/1189498204333543425/1275130785933951039/1278457607291670641. TL;DR: seq_len should not be a constexpr.
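
For context, a minimal sketch of the direction the fix takes (illustrative only, not the actual Liger kernel): pass seq_len as an ordinary runtime integer and handle bounds with masking, so the compiled binary no longer depends on its exact value. (Triton may still specialize integer arguments on properties such as divisibility by 16, but not on each distinct value.)

```python
import torch
import triton
import triton.language as tl


@triton.jit
def _rope_like_kernel(x_ptr, out_ptr, seq_len, BLOCK_SIZE: tl.constexpr):
    # seq_len is now an ordinary runtime argument: changing it does not change
    # the JIT cache key, so the kernel is compiled once and reused.
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < seq_len  # out-of-bounds lanes are masked instead of specialized away
    x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
    tl.store(out_ptr + offsets, x, mask=mask)


# Varying seq_len no longer forces recompilation at every step.
for seq_len in (128, 131, 256):
    x = torch.randn(seq_len, device="cuda")
    out = torch.empty_like(x)
    _rope_like_kernel[(triton.cdiv(seq_len, 64),)](x, out, seq_len, BLOCK_SIZE=64)
```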

yundai424 linked a pull request Aug 28, 2024 that will close this issue
@Jokeren

Jokeren commented Aug 28, 2024

The analysis makes sense to me. Sequence length shouldn't be a constexpr if you want to avoid JIT compiling multiple times.

@ByronHsu
Collaborator

Let's close the ticket once @tyler-romero verifies the fix.

@tyler-romero
Collaborator Author

Looks much better!
[profiler trace screenshot]

@ByronHsu
Collaborator

Thanks folks!

@shreyassks

@tyler-romero hey, could you please let me know which tool you used for profiling?

@tyler-romero
Collaborator Author
