
[Hopper] Update CGA->CTA tiling heuristic and add support for persistent kernels to improve performance of warp specialized kernels with NUM_CTAS>1 #2638

Merged: goostavz merged 8 commits into triton-lang:main from jsh-20:main on Nov 29, 2023

Conversation

@jsh-20 (Contributor) commented Nov 10, 2023

  • Improve the logic for determining splitM and splitN
  • Add the necessary support for persistent kernels with NUM_CTAS>1
  • Add a canonical warp-id query operation to help CSE and LICM before the nvgpu2llvm pass
  • Add a warning for the fallback of warp specialized kernels
@jsh-20 jsh-20 requested a review from ptillet as a code owner November 10, 2023 05:00
@goostavz goostavz changed the title [FRONTEND] Improve perforamnce of NUM_CTAS>1 kernels [Hopper] Improve performance of NUM_CTAS>1 kernels Nov 10, 2023
@goostavz goostavz changed the title [Hopper] Improve performance of NUM_CTAS>1 kernels [Hopper] Update the heuristic rule in CGA->CTA tiling for warp specialization Nov 10, 2023
@goostavz goostavz changed the title [Hopper] Update the heuristic rule in CGA->CTA tiling for warp specialization [Hopper] Update the heuristic rule in CGA->CTA tiling Nov 10, 2023
@jsh-20 jsh-20 changed the title [Hopper] Update the heuristic rule in CGA->CTA tiling [Hopper] Update CGA->CTA tiling heuristic and add support for persistent kernels to improve performance of NUM_CTAS>1 kernels Nov 10, 2023
@jlebar (Contributor) commented Nov 10, 2023

Thank you for the PR, and thank you to others for reviewing this change!

Please let us know when this is ready for OpenAI to have a look.

I understand that in the past, OpenAI hasn't had a lot of bandwidth to look at nvidia PRs. It's created a few problems for us, one of which is that nobody over here groks how some of the nvidia-contributed code works! I'm still new, which means I also don't grok Triton yet, but I also have time to do reviews, because I'm new. :) So I'm going to make it a priority of mine to do careful reviews of nvidia code. Please feel free to ping me here or on Slack if you're not getting the attention you need!

For this PR, would you be willing to include the benchmarks you ran? I believe you that it's faster, but I want to be able to run the same tests so I don't accidentally regress you in the future.

@jsh-20 (Contributor, Author) commented Nov 13, 2023

https://github.com/openai/triton/pull/2638#issuecomment-1806241110
Hi @jlebar, sorry for the late reply. I was struggling to get the related modifications working on the latest code. Unfortunately, I found that many NUM_CTAS>1 cases fail after the recent refactoring of the pipeline pass and other passes (I think ThomasRaoux is working on repairing this). My local branch is based on commit 2217bd2 (#2531) and it works for me. So for now I can only list some results of the persistent warp specialized kernel based on that commit; maybe you can cherry-pick the modifications of this PR and replay them.
[image: benchmark results]
So this feature may need to be delayed until we recover most of the NUM_CTAS>1 logic. @ThomasRaoux, please feel free to let me know when we are ready or if there is anything I can help with.

@jlebar (Contributor) commented Nov 13, 2023

> So now I can only list some results of persistent warp specialized kernel based on this commit

No worries on the delay, and thanks for getting back to me!

Are these benchmarks that are checked into the Triton repository? If so, what are the steps to reproduce?

@ThomasRaoux (Collaborator) commented:

> So maybe this feature need to be delayed until we recover most logics of NUM_CTAS>1, @ThomasRaoux pls feel free to let me know if we are ready or there is anything I can help.

As discussed on Slack, I would prefer we first re-enable functionality independently of pipelining before adding back pipeline support. Will you be working on this?

@jsh-20 (Contributor, Author) commented Nov 14, 2023

> Are these benchmarks that are checked into the Triton repository? If so, what are the steps to reproduce?

Yes, they are checked in. You can add the following test configs to test_full_static_persistent_matmul_kernel in triton/python/test/unit/hopper/test_persistent_warp_specialized_gemm.py and run them in the respective versions (i.e., before and after applying this PR on top of #2531). Then I suppose you can reproduce it. BTW, these selected configs came from autotune.

```python
[
    # ---before---
    [1024, 128, 64, 4, 8, 1300, 1800, 3000, False, False, 'none', 'float16', True, 5, True],
    [512, 256, 64, 4, 8, 800, 30000, 10000, True, True, 'none', 'float16', True, 4, True],
    [1024, 128, 64, 4, 8, 1800, 10000, 15000, True, True, 'none', 'float16', True, 5, True],
    # ---after---
    [512, 256, 64, 4, 8, 1300, 1800, 3000, False, False, 'none', 'float16', True, 5, True],
    [128, 1024, 64, 4, 8, 800, 30000, 10000, True, True, 'none', 'float16', True, 5, True],
    [512, 256, 64, 4, 8, 1800, 10000, 15000, True, True, 'none', 'float16', True, 5, True],
]
```
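The comparison implied above can be sketched as the following command sequence (a hypothetical invocation, not a checked-in script; it assumes a Hopper GPU, an editable Triton checkout, and the test node named in the comment above):

```shell
# Baseline: check out the base commit referenced above (#2531), build Triton,
# then run only the persistent warp-specialized GEMM test.
python -m pytest python/test/unit/hopper/test_persistent_warp_specialized_gemm.py \
    -k test_full_static_persistent_matmul_kernel -v

# Candidate: apply this PR on top of the same base commit, rebuild,
# rerun the same command, and compare the reported timings.
```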

@jsh-20 (Contributor, Author) commented Nov 14, 2023

> As discussed on slack I would prefer we first re-enable functionality independently of pipelining before adding back pipeline support. Will you be working on this?

Previously, I thought only the pipeline pass needed to be updated; however, when I had trouble running the persistent warp specialized kernel on the latest code, I realized other passes were also affected by commits after #2531.

@jsh-20 (Contributor, Author) commented Nov 16, 2023

> I realized other passes were also affected by commits after #2531.

@ThomasRaoux @jlebar The reason for the many failing cases has been found: some modifications for fp8 make the warp specialized kernel fall back to its normal version, which doesn't support NUM_CTAS>1 after the pipeline pass refactoring. @PhrygianGates is helping to modify the fp8-related code to avoid this fallback.

@jsh-20 jsh-20 changed the title [Hopper] Update CGA->CTA tiling heuristic and add support for persistent kernels to improve performance of NUM_CTAS>1 kernels [Hopper] Update CGA->CTA tiling heuristic and add support for persistent kernels to improve performance of warp specialized kernels with NUM_CTAS>1 Nov 16, 2023
@goostavz goostavz (Collaborator) left a comment

LGTM

@goostavz goostavz merged commit 8729550 into triton-lang:main Nov 29, 2023
feihugis pushed a commit to feihugis/triton that referenced this pull request Feb 13, 2024
pingzhuu pushed a commit to siliconflow/triton that referenced this pull request Apr 2, 2024