Skip to content

Commit

Permalink
Address review feedback - 3
Browse files Browse the repository at this point in the history
  • Loading branch information
giuseros committed Sep 24, 2024
1 parent e8fbdc0 commit 95f39c9
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 28 deletions.
69 changes: 48 additions & 21 deletions python/triton/backends/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,18 @@

class AttrsDescriptor:
"""
This class handles the compile-time properties for the given function
parameters. Different backends can add more properties to the common ones.
The class contains two fields:
This class handles compile-time properties for specific function parameters.
Different backends can add more properties to the common ones. The class
contains two fields:
`arg_properties`: a dictionary containing the different compile-time properties for different
parameters. I.e., the dictionary will look like:
parameters. I.e., the dictionary is a map from property names to parameter indices
{
"prop0": (0, 2, 3)
"prop1": (0, 4, 5)
}
Different backend might need different properties on those parameters to enable
Different backends might need different properties on those parameters to enable
specific optimizations. The common compile time properties contained in this class
are :
- "tt.divisibility", i.e., is the given parameter divisible by 16
Expand All @@ -34,21 +35,26 @@ class AttrsDescriptor:
}
"""
__slots__ = ('arg_properties', 'property_val', '__dict__')
__slots__ = ('divisibility', 'equal_to_1', '__dict__')

def __init__(self, params=None, values=None):
"""
Initialize the compile-time properties
We can initialize the AttrsDescriptor class by passing the list of params
of the function and their `values`. The function will try to apply the properties
to the values and save the parameters in the `arg_properties` list. If we don't pass
either the `params` or the `values` we should initialize the class via an alternative method
(see `from_dict` or `from_hints`)
"""
self.property_val = {"tt.divisibility": 16, "tt.equal_to_1": 1}
# Default initialization
self.property_values = {"tt.divisibility": 16, "tt.equal_to_1": 1}
self.constant_properties = {"tt.equal_to_1"}
self.arg_properties = {}
if (params is None) or (values is None):
return

# Compile properties deduction
assert (len(params) == len(values))

# Divisibility property
Expand All @@ -64,30 +70,48 @@ def __init__(self, params=None, values=None):
if AttrsDescriptor.is_equal_to_1(arg) and not param.do_not_specialize
]

def get_fn_attrs(self):
self.init_slots()

def init_slots(self) -> None:
"""
Initialize the slots of this class.
Copies the entries stored under "tt.divisibility" and "tt.equal_to_1" in
`arg_properties` into the `divisibility` and `equal_to_1` attributes. Both
keys must already be present in `arg_properties`, otherwise the lookup
raises KeyError.
"""
self.divisibility = self.arg_properties["tt.divisibility"]
self.equal_to_1 = self.arg_properties["tt.equal_to_1"]

def get_fn_attrs(self) -> Dict:
"""
Get the function attributes as a dictionary.
Every argument listed under a property in `arg_properties` becomes a key whose
value is the list of (property name, property value) pairs that apply to it.
The returned dictionary will look like:
{
arg0 : [(prop_name00, val00), (prop_name01, val01), ...],
arg1 : [(prop_name10, val10), (prop_name11, val11), ...],
}
"""
attrs = {}
for prop_name, arg_set in self.arg_properties.items():
prop_val = self.property_val[prop_name]
prop_val = self.property_values[prop_name]
for arg in arg_set:
# Build a new list each time so an argument accumulates one pair per property
attrs[arg] = attrs.get(arg, []) + [(prop_name, prop_val)]
return attrs

def filter_out_property(self, attr_name):
""" Return the same object, without the given attribute `attr_name`"""
def get_constants(self) -> Dict:
"""
Return a dict of the constant parameters and their values.
For every property name in `constant_properties`, each parameter listed under
that property in `arg_properties` is mapped to the property's value taken from
`property_values`. Properties with no entry in `arg_properties` contribute nothing.
"""
constants = {}
for prop_name in self.constant_properties:
# `.get(..., [])` tolerates constant properties that were never recorded
for p in self.arg_properties.get(prop_name, []):
constants[p] = self.property_values[prop_name]
return constants

def filter_out_constants(self):
    """
    Return a deep copy of this object without the properties marked as constants.

    Every property name listed in `constant_properties` is removed from the
    copy's `arg_properties`, and the copy's `constant_properties` is reset to
    an empty set. The object the method is called on is not modified.
    """
    import copy

    stripped = copy.deepcopy(self)
    for prop_name in stripped.constant_properties:
        # A default for `pop` tolerates constant properties that were never recorded.
        stripped.arg_properties.pop(prop_name, None)
    # `constant_properties` is created as a set in `__init__`; keep the same
    # container type here instead of silently turning it into a dict.
    stripped.constant_properties = set()
    return stripped

def __getitem__(self, attr_name):
def __getitem__(self, attr_name: str):
# Return the entry stored in `arg_properties` for the property `attr_name`;
# a property that was never recorded yields an empty list instead of raising.
if attr_name in self.arg_properties:
return self.arg_properties[attr_name]
return []
Expand All @@ -102,15 +126,18 @@ def to_dict(self):
@staticmethod
def from_hints(hints: dict[int, int]):
"""
Create the class from a set of hints that are passed in.
Instead of deducing the properties from a list of parameters and values,
the user can pass in a mapping `hints={param_index: val}` and if `val`
matches one of the values of the properties (e.g., `property_values[prop0]`),
then we insert `param_index` into the correct list (e.g., in
`arg_properties[prop0]`)
"""
attrsDescriptor = AttrsDescriptor()
for prop_name, prop_val in attrsDescriptor.property_val.items():
for prop_name, prop_val in attrsDescriptor.property_values.items():
# `hints` is iterated with `.items()`, so it must be a mapping of index -> hint value
attrsDescriptor.arg_properties[prop_name] = [i for i, h in hints.items() if h == prop_val]
attrsDescriptor.init_slots()
return attrsDescriptor

@staticmethod
Expand Down
12 changes: 8 additions & 4 deletions python/triton/compiler/code_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1268,9 +1268,9 @@ def kernel_suffix(signature, specialization):
suffix = ''
for i, _ in enumerate(signature):
suffix += str(i)
if i in specialization["tt.equal_to_1"]:
if i in specialization.equal_to_1:
suffix += 'c'
if i in specialization["tt.divisibility"]:
if i in specialization.divisibility:
suffix += 'd'
return suffix

Expand All @@ -1284,8 +1284,12 @@ def ast_to_ttir(fn, specialization, context, options, codegen_fns, module_map):
gscope = fn.__globals__.copy()
function_name = fn.repr(specialization)
tys = list(specialization.signature.values())
new_constants = {k: True if k in tys and tys[k] == "i1" else 1 for k in attrs["tt.equal_to_1"]}
new_attrs = attrs.filter_out_property("tt.equal_to_1")
new_constants = attrs.get_constants()
for k in new_constants:
if k in tys and tys[k] == "i1" and new_constants[k] == 1:
new_constants[k] = True

new_attrs = attrs.filter_out_constants()
fn_attrs = new_attrs.get_fn_attrs()
all_constants = constants.copy()
all_constants.update(new_constants)
Expand Down
3 changes: 2 additions & 1 deletion python/triton/runtime/jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,10 +605,11 @@ def run(self, *args, grid, warmup, **kwargs):
signature = {k: ('*i8' if (v == 'none') else v) for (k, v) in zip(sigkeys, sigvals)}

configs = (backend.get_attrs_descriptor(self.params, bound_vals), )
constant_params = configs[0].get_constants()
constants = {
p.name: v
for (v, p) in zip(bound_vals, self.params)
if p.is_constexpr or p.num in configs[0]["tt.equal_to_1"] or v is None
if p.is_constexpr or (p.num in constant_params) or v is None
}
for i, arg in constants.items():
if callable(arg):
Expand Down
4 changes: 2 additions & 2 deletions python/triton/tools/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,8 @@ def constexpr(s):
for h in hints.values():
assert h in [1, 16], f"Only 1 and 16 are valid hints, got {h}"
attrs = triton.backends.compiler.AttrsDescriptor.from_hints(hints)
for i in attrs["tt.equal_to_1"]:
constants.update({kernel.arg_names[i]: 1})
for p, v in attrs.get_constants().items():
constants.update({kernel.arg_names[p]: v})
src = triton.compiler.ASTSource(fn=kernel, constants=constants, signature=signature, attrs=attrs)
opts = {"num_warps": args.num_warps, "num_stages": args.num_stages}
ccinfo = triton.compile(src, options=opts)
Expand Down

0 comments on commit 95f39c9

Please sign in to comment.