Custom primitive + RoPE fat op #676
Conversation
Some benchmarks of the kernel on my M2 Air.

Before

After

The tests fail on float16 and bfloat16, but only due to numerical issues. Tomorrow I will do a quick check on performance if we do all of the computation in float32 in the kernel, since it probably doesn't matter at all performance-wise.
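To illustrate the idea of keeping the low-precision storage type but doing the math in float32, here is a minimal Metal-style sketch; it is not the kernel in this PR, and the kernel name, buffer layout, and indexing are assumptions for illustration only.

```cpp
// Minimal sketch (not this PR's kernel): rotation math in float32, storage in T.
#include <metal_stdlib>
using namespace metal;

template <typename T>
[[kernel]] void rope_float_math(
    const device T* in [[buffer(0)]],
    device T* out [[buffer(1)]],
    constant const float& base [[buffer(2)]],
    constant const int& half_dim [[buffer(3)]],
    uint2 pos [[thread_position_in_grid]]) {
  // The grid covers half the last dimension: each thread rotates one pair.
  uint i1 = pos.y * (2 * half_dim) + pos.x;
  uint i2 = i1 + half_dim;

  // Angle and rotation are computed in float32 regardless of T.
  float theta = float(pos.y) * pow(base, -float(pos.x) / float(half_dim));
  float c = cos(theta);
  float s = sin(theta);
  float x1 = static_cast<float>(in[i1]);
  float x2 = static_cast<float>(in[i2]);

  // Cast back to the storage type only when writing out.
  out[i1] = static_cast<T>(x1 * c - x2 * s);
  out[i2] = static_cast<T>(x1 * s + x2 * c);
}
```

With float16 or bfloat16 inputs the intermediate cos/sin values and products stay in float32, which is typically enough to bring the low-precision tests within tolerance.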
if (dims_ != in.shape(-1)) {
  throw std::runtime_error("[RoPE] Partial RoPE application not supported");
}
if (in.flags().row_contiguous && in.is_donatable()) {
We need a contig and copy check before this right?
Not sure what the copy check is. Also, row_contiguous is stricter than contiguous, is it not? I.e. all row_contiguous arrays are contiguous but not the other way around.
I meant, if it's not contiguous, we should make a contiguous copy
It does not appear to me that your kernel handles non-contiguous inputs, but maybe I missed something...
Actually I think I missed it; I was looking for elem_to_loc, but you hardcoded the strides, so it should be ok.
Do we need to check here, though, that the input has the same size as the output? If it's broadcasted, e.g. along the last axis, it would be incorrect to donate, right?
Yeah, I hardcoded the strides because the grid is launched with half the last dimension, so it can't be delegated to a simple elem_to_loc. I would have to do something like multiply pos.x by 2 and then pass it to elem_to_loc, etc. I think this is equally readable, but I am open to suggestions :-)

Regarding broadcasting, a broadcasted array wouldn't be row_contiguous, so this check should be fine donation-wise, right?
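As a sketch of the check being discussed, the donate-or-allocate logic could look roughly like this (assuming MLX's array/allocator API, i.e. flags().row_contiguous, is_donatable, move_shared_buffer, and malloc_or_wait; this is not necessarily the PR's exact code):

```cpp
#include "mlx/allocator.h"
#include "mlx/array.h"

using namespace mlx::core;

// Reuse the input buffer when it is safe to do so, otherwise allocate.
void prepare_output(array& in, array& out) {
  if (in.flags().row_contiguous && in.is_donatable()) {
    // A broadcast view is never row_contiguous, so reaching this branch means
    // the input has the same number of elements as the output and its buffer
    // can be reused in place.
    out.move_shared_buffer(in);
  } else {
    // Otherwise allocate fresh memory for the output.
    out.set_data(allocator::malloc_or_wait(out.nbytes()));
  }
}
```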
Oh of course! Let me quietly exit this thread before I say anything else incorrect
Wow, that's so fast! We can also increase the tolerance for the lower-precision tests if that's simpler.
Make this a real PR since I think we are almost done.
This looks great. I really like the Custom primitive.
python/src/extensions.cpp
"traditional"_a, | ||
"base"_a, | ||
"scale"_a, | ||
"offset"_a, |
Do you think we should make the above keyword-only? It would be verbose but error-free...
I think so, yes.
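A minimal sketch of what keyword-only arguments would look like with pybind11's py::kw_only() (the function name and the stub signature below are placeholders, not the PR's actual binding code):

```cpp
#include <pybind11/pybind11.h>

namespace py = pybind11;
using namespace py::literals;

void bind_rope_example(py::module_& m) {
  // Placeholder callable standing in for the actual rope implementation.
  auto rope_stub = [](py::object x, int dims, bool traditional, float base,
                      float scale, int offset) { return x; };
  m.def(
      "rope",
      rope_stub,
      "x"_a,
      py::kw_only(),  // every argument after this must be passed by keyword
      "dims"_a,
      "traditional"_a,
      "base"_a,
      "scale"_a,
      "offset"_a);
}
```

Callers would then have to write rope(x, dims=..., traditional=..., base=..., scale=..., offset=...), which is verbose but rules out silently swapped positional arguments.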
* extensions start
* rope custom op
* fix build
* docs + rope benchmark
* fix test
* Add a Metal kernel for RoPE
* Fix position of traditional
* transform tests
* Move rope computation to float and fix tests
* Fix the test and a typo
* change to fast
* fix no metal build

---------

Co-authored-by: Angelos Katharopoulos <[email protected]>
Proposed changes