[Hopper] Update CGA->CTA tiling heuristic and add support for persistent kernels to improve performance of warp specialized kernels with NUM_CTAS>1 #2638
Thank you for the PR, and thank you to others for reviewing this change! Please let us know when this is ready for OpenAI to have a look.

I understand that in the past, OpenAI hasn't had a lot of bandwidth to look at nvidia PRs. It's created a few problems for us, one of which is that nobody over here groks how some of the nvidia-contributed code works! I'm still new, which means I also don't grok Triton yet, but I also have time to do reviews, because I'm new. :) So I'm going to make it a priority of mine to do careful reviews of nvidia code. Please feel free to ping me here or on Slack if you're not getting the attention you need!

For this PR, would you be willing to include the benchmarks you ran? I believe you that it's faster, but I want to be able to run the same tests so I don't accidentally regress you in the future.
https://github.com/openai/triton/pull/2638#issuecomment-1806241110
No worries on the delay, and thanks for getting back to me! Are these benchmarks that are checked into the Triton repository? If so, what are the steps to reproduce?
As discussed on Slack, I would prefer that we first re-enable the functionality independently of pipelining before adding back pipeline support. Will you be working on this?
Yes, they are checked in. You can add the following test config in
Previously, I thought only the pipeline pass needed to be supplemented. However, when I had trouble running a persistent warp specialized kernel on the latest code, I realized that other passes were also affected by commits after #2531.
@ThomasRaoux @jlebar The cause of the many failing cases has been found: some modifications for fp8 make warp specialized kernels fall back to their normal version, which doesn't support NUM_CTAS>1 after the pipeline pass refactoring. @PhrygianGates is helping to modify the fp8-related code to avoid this fallback.
…euristic to improve performance of NUM_CTAS>1 persistent warp specialized kernels
…ent kernels to improve performance of warp specialized kernels with NUM_CTAS>1 (triton-lang#2638)
- Improve the logic for determining splitM and splitN
- Add the necessary support for persistent kernels with NUM_CTAS>1
- Add a canonical warp id query operation to help CSE and LICM before the nvgpu2llvm pass
- Add a warning for fallback of warp specialized kernels
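For readers unfamiliar with the term, the general idea behind a persistent kernel can be sketched in plain Python: a fixed number of programs is launched, and each one loops over output tiles with a stride equal to the number of programs, instead of launching one program per tile. This is only an illustration of the scheduling pattern, not the code in this PR; the function names `persistent_tile_schedule` and `tile_coords` and the row-major tile ordering are assumptions made for the example.

```python
def persistent_tile_schedule(num_tiles, num_programs):
    """Return, for each program id, the tile ids it would process.

    Hypothetical helper: models a persistent kernel where program `pid`
    handles tiles pid, pid + num_programs, pid + 2 * num_programs, ...
    """
    return [list(range(pid, num_tiles, num_programs))
            for pid in range(num_programs)]


def tile_coords(tile_id, num_tiles_n):
    """Map a linear tile id to (tile_m, tile_n), assuming row-major order."""
    return divmod(tile_id, num_tiles_n)


if __name__ == "__main__":
    # 10 output tiles shared among 4 persistent programs.
    for pid, tiles in enumerate(persistent_tile_schedule(10, 4)):
        coords = [tile_coords(t, num_tiles_n=5) for t in tiles]
        print(f"program {pid}: tiles {tiles} -> coords {coords}")
```

In a real kernel the loop body would load, compute, and store one tile per iteration; the sketch only shows how tiles are distributed so that every tile is visited exactly once.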