[SYCL] Use batched mul_mat pathway #5591
Conversation
@NeoZhangJianyu, @abhilash1910, @Alcpz, feedback would be appreciated
LGTM. I think we can use this until MKL adds the dtypes for batched gemm. Pinging @airMeng @ggerganov for a look when available.
@AidanBeltonS could you please rebase; that should fix the Android build issue. Thanks
Force-pushed from cb21f6c to abed262
* Use batched mul_mat pathway
* rm extra line
* Explicitly state scaled data type

Co-authored-by: Abhilash Majumder <[email protected]>
This PR enables use of the batched mul_mat pathway when appropriate. Previously the single-gemm path was taken even though it was not suitable for this type of operation, causing segfaults. This PR changes things to more closely match the CUDA implementation and use the batched gemm path.
This change allows many more tests to pass on SYCL devices. There is one limitation with this approach: we cannot use non-default precision operations, because oneMKL has not yet open sourced `gemm_batch` for the data types <half, half, float, float> (corresponding to <src0, src1, dst, scaling>). This is something I have raised with oneMKL.
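For readers unfamiliar with the batched pathway, here is a minimal reference sketch of what a batched gemm computes: a set of independent matrix multiplies `C[b] = alpha * A[b] * B[b]` dispatched as one logical call. This is a hypothetical plain-C++ illustration of the semantics, not the oneMKL `gemm_batch` API or the actual SYCL kernel in this PR.

```cpp
#include <cstddef>
#include <vector>

// Reference semantics of a batched GEMM over `batch` independent problems.
// Each A[b] is m x k, B[b] is k x n, C[b] is m x n, stored contiguously
// per batch in row-major order. `alpha` plays the role of the scaling
// data type mentioned in the PR description.
void gemm_batch_ref(const std::vector<float>& A,
                    const std::vector<float>& B,
                    std::vector<float>& C,
                    std::size_t batch, std::size_t m, std::size_t n,
                    std::size_t k, float alpha) {
    for (std::size_t b = 0; b < batch; ++b) {
        const float* a  = A.data() + b * m * k;
        const float* bp = B.data() + b * k * n;
        float*       c  = C.data() + b * m * n;
        for (std::size_t i = 0; i < m; ++i) {
            for (std::size_t j = 0; j < n; ++j) {
                float acc = 0.0f;
                for (std::size_t p = 0; p < k; ++p) {
                    acc += a[i * k + p] * bp[p * n + j];
                }
                c[i * n + j] = alpha * acc;
            }
        }
    }
}
```

Taking the single-gemm path instead would treat only one of these sub-problems (or the wrong memory layout) per call, which is consistent with the segfaults described above.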