diff --git a/python/triton/backends/compiler.py b/python/triton/backends/compiler.py index 49b8c3899a8d..9cb19309c4e9 100644 --- a/python/triton/backends/compiler.py +++ b/python/triton/backends/compiler.py @@ -11,17 +11,18 @@ class AttrsDescriptor: """ - This class handles the compile-time properties for the given function - parameters. Different backends can add more properties to the common ones. - The class contains two fields: + This class handles compile-time properties for specific function parameters. + + Different backends can add more properties to the common ones. The class + contains two fields: `arg_properties`: a dictionary containing the different compile-time properties for different - parameters. I.e., the dictionary will look like: + parameters. I.e., the dictionary is a map from property names to parameter indices { "prop0": (0, 2, 3) "prop1": (0, 4, 5) } - Different backend might need different properties on those paraemters to enable + Different backends might need different properties on those parameters to enable specific optimizations. The common compile time properties contained in this class are : - "tt.divisibility", i.e., is the given parameter divisible by 16 @@ -34,21 +35,26 @@ class AttrsDescriptor: } """ - __slots__ = ('arg_properties', 'property_val', '__dict__') + __slots__ = ('divisibility', 'equal_to_1', '__dict__') def __init__(self, params=None, values=None): """ + Initialize the compile-time properties + We can initialize the AttrsDescriptor class by passing the list of params of the function and their `values`. The function will try to apply the properties to the values and save the parameters in the `arg_properties` list. 
If we don't pass either the `params` or the `values` we should initialize the class via an alternative method (see `from_dict` or `from_hints`) """ - self.property_val = {"tt.divisibility": 16, "tt.equal_to_1": 1} + # Default initialization + self.property_values = {"tt.divisibility": 16, "tt.equal_to_1": 1} + self.constant_properties = {"tt.equal_to_1"} self.arg_properties = {} if (params is None) or (values is None): return + # Compile properties deduction assert (len(params) == len(values)) # Divisibility property @@ -64,9 +70,18 @@ def __init__(self, params=None, values=None): if AttrsDescriptor.is_equal_to_1(arg) and not param.do_not_specialize ] - def get_fn_attrs(self): + self.init_slots() + + def init_slots(self): + """ Initialize the slots of this class """ + self.divisibility = self.arg_properties["tt.divisibility"] + self.equal_to_1 = self.arg_properties["tt.equal_to_1"] + + def get_fn_attrs(self) -> Dict: """ - Get the function attributes as a dict like: + Get the function attributes as a dictionary. 
+ + The returned dictionary will look like : { "arg0" : [(prop_name00, val00), (prop_name01, val01), ...)]} "arg1" : [(prop_name10, val10), (prop_name11, val11), ...)]} @@ -74,20 +89,29 @@ def get_fn_attrs(self): """ attrs = {} for prop_name, arg_set in self.arg_properties.items(): - prop_val = self.property_val[prop_name] + prop_val = self.property_values[prop_name] for arg in arg_set: attrs[arg] = attrs.get(arg, []) + [(prop_name, prop_val)] return attrs - def filter_out_property(self, attr_name): - """ Return the same object, without the given attribute `attr_name`""" + def get_constants(self) -> Dict: + """ Return a dict of constant properties and their values """ + constants = {} + for prop_name in self.constant_properties: + for p in self.arg_properties.get(prop_name, []): + constants[p] = self.property_values[prop_name] + return constants + + def filter_out_constants(self): + """ Return the same object, without properties marked as constants""" import copy c = copy.deepcopy(self) - if attr_name in c.arg_properties: - c.arg_properties.pop(attr_name) + for prop_name in c.constant_properties: + c.arg_properties.pop(prop_name, None) + c.constant_properties = {} return c - def __getitem__(self, attr_name): + def __getitem__(self, attr_name: str): if attr_name in self.arg_properties: return self.arg_properties[attr_name] return [] @@ -102,15 +126,18 @@ def to_dict(self): @staticmethod def from_hints(hints: list[tuple[int, int]]): """ - Create the class from a set of hints that are passed in. So, instead - of deducing the properties from a list of paramaters and values, the user - can pass in a list of `hints=[(param_index, val)]` and if `val` matches - one of the values of the properties (e.g., `prop_val[prop0]`), then we insert - `param_index` into the correct list (e.g., in `arg_properties[prop0]`) + Create the class from a set of hints that are passed in. 
+ + Instead of deducing the properties from a list of parameters and values, + the user can pass in a list of `hints=[(param_index, val)]` and if `val` + matches one of the values of the properties (e.g., `property_values[prop0]`), + then we insert `param_index` into the correct list (e.g., in + `arg_properties[prop0]`) """ attrsDescriptor = AttrsDescriptor() - for prop_name, prop_val in attrsDescriptor.property_val.items(): + for prop_name, prop_val in attrsDescriptor.property_values.items(): attrsDescriptor.arg_properties[prop_name] = [i for i, h in hints.items() if h == prop_val] + attrsDescriptor.init_slots() return attrsDescriptor @staticmethod diff --git a/python/triton/compiler/code_generator.py b/python/triton/compiler/code_generator.py index df05c5079727..afe3b551a474 100644 --- a/python/triton/compiler/code_generator.py +++ b/python/triton/compiler/code_generator.py @@ -1268,9 +1268,9 @@ def kernel_suffix(signature, specialization): suffix = '' for i, _ in enumerate(signature): suffix += str(i) - if i in specialization["tt.equal_to_1"]: + if i in specialization.equal_to_1: suffix += 'c' - if i in specialization["tt.divisibility"]: + if i in specialization.divisibility: suffix += 'd' return suffix @@ -1284,8 +1284,12 @@ def ast_to_ttir(fn, specialization, context, options, codegen_fns, module_map): gscope = fn.__globals__.copy() function_name = fn.repr(specialization) tys = list(specialization.signature.values()) - new_constants = {k: True if k in tys and tys[k] == "i1" else 1 for k in attrs["tt.equal_to_1"]} - new_attrs = attrs.filter_out_property("tt.equal_to_1") + new_constants = attrs.get_constants() + for k in new_constants: + if k in tys and tys[k] == "i1" and new_constants[k] == 1: + new_constants[k] = True + + new_attrs = attrs.filter_out_constants() fn_attrs = new_attrs.get_fn_attrs() all_constants = constants.copy() all_constants.update(new_constants) diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py index 
b06a6f32d9d5..552c1eadf156 100644 --- a/python/triton/runtime/jit.py +++ b/python/triton/runtime/jit.py @@ -605,10 +605,11 @@ def run(self, *args, grid, warmup, **kwargs): signature = {k: ('*i8' if (v == 'none') else v) for (k, v) in zip(sigkeys, sigvals)} configs = (backend.get_attrs_descriptor(self.params, bound_vals), ) + constant_params = configs[0].get_constants() constants = { p.name: v for (v, p) in zip(bound_vals, self.params) - if p.is_constexpr or p.num in configs[0]["tt.equal_to_1"] or v is None + if p.is_constexpr or (p.num in constant_params) or v is None } for i, arg in constants.items(): if callable(arg): diff --git a/python/triton/tools/compile.py b/python/triton/tools/compile.py index fb3df0cea28d..19e0123a9761 100644 --- a/python/triton/tools/compile.py +++ b/python/triton/tools/compile.py @@ -108,8 +108,8 @@ def constexpr(s): for h in hints.values(): assert h in [1, 16], f"Only 1 and 16 are valid hints, got {h}" attrs = triton.backends.compiler.AttrsDescriptor.from_hints(hints) - for i in attrs["tt.equal_to_1"]: - constants.update({kernel.arg_names[i]: 1}) + for p, v in attrs.get_constants().items(): + constants.update({kernel.arg_names[p]: v}) src = triton.compiler.ASTSource(fn=kernel, constants=constants, signature=signature, attrs=attrs) opts = {"num_warps": args.num_warps, "num_stages": args.num_stages} ccinfo = triton.compile(src, options=opts)