Refactor compiler specializations to consider backend #4734

Merged
merged 7 commits into from
Oct 2, 2024

giuseros
Contributor

@giuseros giuseros commented Sep 16, 2024

In this PR I am trying to refactor the specializations that we apply to the signature of a given function in Triton.

Basically, given a kernel, some argument properties can help compilation, e.g., divisibility by 16 or an integer argument being equal to 1.

In a previous PR, #4716, I needed other specializations to add buffer support in the AMD backend (and to get back some performance when we were using unaligned masked loads).

So this is my attempt to redesign the specialization support to introduce per-backend specializations. The idea is that AttrsDescriptor is now the class that analyzes the parameters and adds the specializations. It also has a function table to which each backend can add more specializations.
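As a rough illustration of the design described above, here is a minimal sketch of a descriptor that analyzes argument values through a function table that a backend can extend. Every name in it (`SketchAttrs`, `SketchAmdAttrs`, `add_backend_checks`, and the `non_negative` property) is hypothetical and chosen for illustration; it is not the AttrsDescriptor code from this PR.

```python
from typing import Callable, Dict, List


def _is_int(value: object) -> bool:
    # bool is a subclass of int in Python, so exclude it explicitly.
    return isinstance(value, int) and not isinstance(value, bool)


class SketchAttrs:
    """Hypothetical stand-in for a descriptor with a function table."""

    def __init__(self, values: List[object]) -> None:
        # Function table: property name -> predicate on an argument value.
        self.checks: Dict[str, Callable[[object], bool]] = {
            "divisible_by_16": lambda v: _is_int(v) and v % 16 == 0,
            "equal_to_1": lambda v: _is_int(v) and v == 1,
        }
        # Let a backend subclass register extra predicates before analysis.
        self.add_backend_checks()
        # For each property, record the indices of the arguments that have it.
        self.arg_properties: Dict[str, List[int]] = {
            name: [i for i, v in enumerate(values) if check(v)]
            for name, check in self.checks.items()
        }

    def add_backend_checks(self) -> None:
        """Hook for backends; the common class adds nothing."""


class SketchAmdAttrs(SketchAttrs):
    def add_backend_checks(self) -> None:
        # Made-up backend-only property, purely for illustration.
        self.checks["non_negative"] = lambda v: _is_int(v) and v >= 0


values = [32, 1, -7, 48]
common = SketchAttrs(values).arg_properties
amd = SketchAmdAttrs(values).arg_properties
```

In this sketch the common properties are identical for both classes, and the backend subclass only adds entries to the table, so backend-independent code can keep reading the shared properties.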

@giuseros
Contributor Author

Hi @ThomasRaoux , @antiagainst , this is my best attempt to redesign the specialization collection. Please let me know what you think.

@antiagainst antiagainst changed the title from "Refactor compiler specializaitons" to "Refactor compiler specializations to consider backend" on Sep 18, 2024
Collaborator

@ThomasRaoux ThomasRaoux left a comment

This doesn't seem to handle all the places where we deal with specialization, such as compute_spec_key (this one is a bit delicate: it is used to compute the hash key, so its execution time matters).
Also, I was hoping for something where we call into the existing backend classes. I haven't looked in detail at how easy that would be, though.

@giuseros
Contributor Author

Hi @ThomasRaoux , I tried to apply your suggestions:

  • I moved the class that stores the properties of the different parameters into the common backend compiler class.
  • I added two methods to the BaseBackend class: get_attrs_descriptor and compute_spec_key. These will be used throughout the different code generation phases (i.e., to generate the properties and to create a key to retrieve those properties).
  • The properties are no longer stored as a tuple. The value of each property (16 or 1, in the current state of things) is stored in a separate structure.
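The two methods named above can be pictured with a minimal sketch. The method names come from the comment, but the signatures, the key letters, the `SketchBackend` class, and the `property_values` layout are assumptions made for illustration, not the code merged in this PR.

```python
class SketchBackend:
    """Hypothetical backend; only the two method names come from the PR."""

    # Assumed layout: the specialized value of each property (16 or 1)
    # lives here, separately from the per-argument index lists.
    property_values = {"divisibility": 16, "equal_to": 1}

    @staticmethod
    def compute_spec_key(arg):
        # Kept deliberately cheap, since a key like this is computed on
        # every kernel launch to look up the cached compilation.
        if isinstance(arg, int) and not isinstance(arg, bool):
            if arg % 16 == 0:
                return "D"  # divisible by 16
            if arg == 1:
                return "1"  # equal to 1
        return "N"  # nothing to specialize on

    def get_attrs_descriptor(self, args):
        # Map each property name to the indices of the arguments that have it.
        props = {name: [] for name in self.property_values}
        for index, arg in enumerate(args):
            key = self.compute_spec_key(arg)
            if key == "D":
                props["divisibility"].append(index)
            elif key == "1":
                props["equal_to"].append(index)
        return props


backend = SketchBackend()
args = [64, 1, 3.5, 10]
cache_key = "".join(backend.compute_spec_key(a) for a in args)
attrs = backend.get_attrs_descriptor(args)
```

Here the per-argument keys are concatenated into one string that could serve as part of a cache key, while the descriptor carries the same information in a form the code generator can query by property name.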

Please take another look and let me know what you think!

Collaborator

@antiagainst antiagainst left a comment

Thanks for this! A few comments inlined. Bigger questions are:

  1. There is hardcoded logic in the runtime JIT that checks the currently supported properties. We should think about the limitations on backend-specific properties, be clear about what they can and cannot do, and document that. It's fine for backend-specific properties to influence code generation only, IMO.
  2. I'm wondering whether we can make the AttrsDescriptor better structured with __slots__.
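For readers unfamiliar with `__slots__`, the second point can be sketched as follows. The class and attribute names are hypothetical and only illustrate the Python feature; they are not taken from the PR.

```python
class SlottedAttrs:
    """Hypothetical descriptor whose attribute set is fixed by __slots__."""

    # Only these attributes may exist on an instance; no per-instance
    # __dict__ is created.
    __slots__ = ("divisibility_16", "equal_to_1")

    def __init__(self, divisibility_16=(), equal_to_1=()):
        self.divisibility_16 = tuple(divisibility_16)
        self.equal_to_1 = tuple(equal_to_1)


attrs = SlottedAttrs([0, 2], [1])

try:
    # A misspelled attribute is rejected instead of silently created.
    attrs.divisibilty_16 = (3,)
except AttributeError:
    typo_rejected = True
else:
    typo_rejected = False
```

The trade-off is that a fixed slot list makes the structure explicit and catches typos, but a backend that wants extra properties would need to declare them in a subclass rather than attach them dynamically.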

@giuseros
Contributor Author

I had a go at reproducing the NVIDIA failures, without luck: everything seems OK. Let's see if it was a glitch; otherwise I will work harder on reproducing the error.

Collaborator

@antiagainst antiagainst left a comment

Hopefully the last batch of comments. :)

@giuseros
Contributor Author

Hi @antiagainst , I addressed your comments. PTAL. I still see an NVIDIA failure, but I want to get the code right before diving into that.

Collaborator

@antiagainst antiagainst left a comment

Nice, LGTM now! Thanks for bearing with my nitpicking. :) I'll leave the final approval to Thomas though.

@giuseros
Contributor Author

Thank you for your review @antiagainst !

Collaborator

@ThomasRaoux ThomasRaoux left a comment

I think that's fine, just added one comment. I do wonder about the runtime performance impact. I don't see any red flags, but I'll check if this is something we have an easy way to benchmark.

@giuseros
Contributor Author

giuseros commented Oct 2, 2024

Hi @ThomasRaoux , @antiagainst , any updates on this?

Collaborator

@ThomasRaoux ThomasRaoux left a comment

LGTM

@antiagainst antiagainst merged commit cd1cc2d into triton-lang:main Oct 2, 2024
7 checks passed