[Question] triton jit of LigerRopeFunction runs every step when variable length sequences are used #146
Comments
triton-lang/triton#3166 seems related to this. We probably need help from the Triton folks @ptillet @Jokeren. Can you also cross-post in a Triton issue?
cc @yundai424 @lancerts if you have insights
My hunch is that it has something to do with the fact that we treat sequence length as a constexpr in the RoPE kernel 🤔 If it has to be known at compile time, then it makes sense to me that each sequence length gets a new function signature / cache entry.
See the discussion thread at https://discord.com/channels/1189498204333543425/1275130785933951039/1278457607291670641. tl;dr: seq_len should not be a constexpr.
The analysis makes sense to me. Sequence length shouldn't be a constexpr if you want to avoid JIT-compiling the kernel multiple times.
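For reference, here is a minimal sketch of the difference (toy copy kernels with made-up names, not the actual Liger RoPE implementation). When SEQ_LEN is a tl.constexpr, Triton bakes the value into the compiled binary, so every distinct sequence length produces a new cache entry and another JIT pass; when seq_len is an ordinary runtime argument, one compiled kernel is reused across lengths:

```python
import torch
import triton
import triton.language as tl


@triton.jit
def copy_kernel_constexpr(x_ptr, y_ptr, SEQ_LEN: tl.constexpr, BLOCK_SIZE: tl.constexpr):
    # SEQ_LEN is baked into the binary: each new value -> a new compile + cache entry.
    pid = tl.program_id(0)
    offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offs < SEQ_LEN
    tl.store(y_ptr + offs, tl.load(x_ptr + offs, mask=mask), mask=mask)


@triton.jit
def copy_kernel_runtime(x_ptr, y_ptr, seq_len, BLOCK_SIZE: tl.constexpr):
    # seq_len is a plain runtime argument: one compile covers all lengths.
    pid = tl.program_id(0)
    offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offs < seq_len
    tl.store(y_ptr + offs, tl.load(x_ptr + offs, mask=mask), mask=mask)


def run(seq_len: int):
    x = torch.randn(seq_len, device="cuda")
    y = torch.empty_like(x)
    grid = (triton.cdiv(seq_len, 128),)
    # With dynamic padding seq_len changes every step, so this variant
    # re-JITs on (almost) every call...
    copy_kernel_constexpr[grid](x, y, SEQ_LEN=seq_len, BLOCK_SIZE=128)
    # ...while this one compiles once and reuses the cached binary.
    copy_kernel_runtime[grid](x, y, seq_len, BLOCK_SIZE=128)


if torch.cuda.is_available():
    for seq_len in (1000, 1001, 1002):  # dynamic padding -> a new length each step
        run(seq_len)
```

(As far as I understand, Triton can still specialize on plain integer arguments in limited ways, e.g. divisibility by 16, but that produces at most a few compiled variants rather than one per distinct sequence length.)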
Let's close the ticket once @tyler-romero verifies the fix.
Thanks folks!
@tyler-romero hey, could you please let me know which tool you used for profiling?
@shreyassks I just used PyTorch's built-in profiler: https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html#using-tracing-functionality |
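In case it helps anyone else, here is a minimal, self-contained sketch of that recipe (placeholder model and data, not my actual training setup), including the trace export:

```python
import torch
from torch.profiler import profile, ProfilerActivity

device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Linear(1024, 1024).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

activities = [ProfilerActivity.CPU]
if device == "cuda":
    activities.append(ProfilerActivity.CUDA)

with profile(activities=activities, record_shapes=True) as prof:
    for _ in range(5):  # a few steps is enough to see whether the JIT repeats
        x = torch.randn(8, 1024, device=device)
        loss = model(x).sum()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

# Summary table plus a Chrome trace you can open in chrome://tracing or Perfetto
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=20))
prof.export_chrome_trace("trace.json")
```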
🐛 Describe the bug
When training a model using LigerKernels, I don't get the speedup I expect. Profiling with and without LigerKernels shows that the bottleneck seems to be the triton JIT of LigerRopeFunction. This seems to happen at every step, instead of only once. Other LigerKernels do not seem to take as long to JIT (or the JIT just isn't happening after the first step).

Side-by-side comparison of rope with and without Liger: (screenshot)

LigerRopeFunction in the context of an entire forward pass: (screenshot)
I am training with dynamic padding (so different sequence lengths at every forward pass). It seems like this is currently the only LigerKernel that depends on seq_len in a non-batch dimension (LigerCrossEntropyFunction depends on seq_len, but it is pushed into the batch dimension). Does this intuition about why LigerRopeFunction is JIT-compiled every forward pass but LigerCrossEntropyFunction is not make sense?

Reproduce
No response
Versions