diff --git a/python/triton/runtime/autotuner.py b/python/triton/runtime/autotuner.py index 0a6394ac1e82..3cb9f9dbe862 100644 --- a/python/triton/runtime/autotuner.py +++ b/python/triton/runtime/autotuner.py @@ -142,6 +142,7 @@ def warmup(self, *args, **kwargs): class Config: """ An object that represents a possible kernel configuration for the auto-tuner to try. + :ivar meta: a dictionary of meta-parameters to pass to the kernel as keyword arguments. :type meta: dict[Str, Any] :ivar num_warps: the number of warps to use for the kernel when compiled for GPUs. For example, if @@ -173,8 +174,10 @@ def __str__(self): def autotune(configs, key, prune_configs_by=None, reset_to_zero=None): """ Decorator for auto-tuning a :code:`triton.jit`'d function. + .. highlight:: python .. code-block:: python + @triton.autotune(configs=[ triton.Config(meta={'BLOCK_SIZE': 128}, num_warps=4), triton.Config(meta={'BLOCK_SIZE': 1024}, num_warps=8), @@ -223,8 +226,10 @@ def heuristics(values): """ Decorator for specifying how the values of certain meta-parameters may be computed. This is useful for cases where auto-tuning is prohibitevely expensive, or just not applicable. + .. highlight:: python .. code-block:: python + @triton.heuristics(values={'BLOCK_SIZE': lambda args: 2 ** int(math.ceil(math.log2(args[1])))}) @triton.jit def kernel(x_ptr, x_size, **META):