Add Grouped GEMM for Mixed Dtype #457
Conversation
examples/sycl/10_bmg_grouped_gemm_mixed_dtype/10_bmg_grouped_gemm_mixed_dtype.cpp
examples/sycl/10_bmg_grouped_gemm_mixed_dtype/bmg_grouped_gemm_mixed_dtype_runner.hpp
LGTM
Can we refactor include/cutlass/gemm/collective/xe_array_mma_mixed_input.hpp and include/cutlass/gemm/collective/xe_mma_mixed_input.hpp together? I found that they share a lot of common code.
I tried to factor out the differences and dispatch them to xe_mma_mixed_input.hpp. The biggest difference is in initializing the params: the array MMA must initialize the tiled copy with each individual tensor and update the group index, so it is not easy.
The quantization and operator() (the GEMM main loop) are the same, and those are the most important parts of the implementation. Can we make a base struct such as xe_mma_mixed_dtype_base that contains these common parts, and have your grouped mixed GEMM inherit from it?
True, but that approach would touch a lot of files. I will open a separate PR to deal with it.
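For illustration, a minimal sketch of the base-struct refactoring proposed above (all names here, e.g. XeMmaMixedDtypeBase and update_group, are hypothetical, not the PR's actual code): the shared dequantization and main-loop logic live in one base struct, and the grouped (array) collective adds only the per-group tiled-copy initialization and group-index tracking.

```cpp
// Hypothetical sketch of the suggested refactoring; not the actual CUTLASS API.
#include "cutlass/cutlass.h"

namespace cutlass::gemm::collective {

template <class TiledMma>
struct XeMmaMixedDtypeBase {
  // Common part: dequantize the low-precision fragment before the MMA.
  template <class Frag, class Scale>
  CUTLASS_DEVICE void dequantize(Frag& frag, Scale const& scale) {
    // ... shared quantization logic ...
  }

  // Common part: the GEMM main loop, identical for both collectives.
  template <class FragA, class FragB, class Accum>
  CUTLASS_DEVICE void operator()(FragA const& a, FragB const& b, Accum& c) {
    // ... shared main-loop logic ...
  }
};

// Non-grouped collective: only single-tensor params setup is specific.
template <class TiledMma>
struct XeMmaMixedInput : XeMmaMixedDtypeBase<TiledMma> {
  // ... initialize tiled copies from the single A/B tensors ...
};

// Grouped (array) collective: this is where the two files differ, so only
// the per-group re-initialization lives here.
template <class TiledMma>
struct XeArrayMmaMixedInput : XeMmaMixedDtypeBase<TiledMma> {
  int group_idx = 0;  // updated as the kernel moves between groups

  CUTLASS_DEVICE void update_group(int next_group) {
    group_idx = next_group;
    // ... rebuild tiled copies with the tensors of the new group ...
  }
};

}  // namespace cutlass::gemm::collective
```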
Force-pushed from 6c8c6a7 to e9d1004.
This PR adds grouped GEMM support for mixed-dtype GEMM.