From 5540d5c25456910f7b65ca40ce648e69507a06d6 Mon Sep 17 00:00:00 2001 From: Mogball Date: Wed, 15 Apr 2026 16:53:09 -0700 Subject: [PATCH 1/3] typing improvements --- python/triton/experimental/gluon/_runtime.py | 21 +++++++++- python/triton/language/core.py | 10 ++++- python/triton/runtime/jit.py | 42 +++++++++++--------- 3 files changed, 53 insertions(+), 20 deletions(-) diff --git a/python/triton/experimental/gluon/_runtime.py b/python/triton/experimental/gluon/_runtime.py index 5b55d6fa526c..2f303b5f89a3 100644 --- a/python/triton/experimental/gluon/_runtime.py +++ b/python/triton/experimental/gluon/_runtime.py @@ -2,7 +2,7 @@ from triton.compiler.compiler import ASTSource from triton.backends.compiler import Language from triton.runtime.jit import JITFunction, constexpr_function -from typing import TypeVar, Optional, Callable, Iterable, Union +from typing import TypeVar, Optional, Callable, Iterable, Union, overload from triton._C.libtriton import ir T = TypeVar("T") @@ -53,6 +53,25 @@ def is_gluon(self): return True +@overload +def jit(fn: T) -> GluonJITFunction[T]: + ... + + +@overload +def jit( + *, + version=None, + repr: Optional[Callable] = None, + launch_metadata: Optional[Callable] = None, + do_not_specialize: Optional[Iterable[int | str]] = None, + do_not_specialize_on_alignment: Optional[Iterable[int | str]] = None, + debug: Optional[bool] = None, + noinline: Optional[bool] = None, +) -> Callable[[T], GluonJITFunction[T]]: + ... + + def jit( fn: Optional[T] = None, *, diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 61375297464e..3b59a7ee2812 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -6,7 +6,7 @@ from enum import Enum from functools import partial, wraps, cached_property import typing -from typing import Union, Callable, List, Sequence, TypeVar, Optional, Tuple +from typing import Union, Callable, List, Sequence, TypeVar, Optional, Tuple, TYPE_CHECKING from dataclasses import dataclass import builtins from .. import knobs @@ -1584,6 +1584,14 @@ def _wrap_init_args(x): return constexpr(x) +if TYPE_CHECKING: + from typing_extensions import dataclass_transform +else: + def dataclass_transform(**kwargs): + return lambda obj: obj + + +@dataclass_transform(eq_default=False) def _aggregate(cls): field_annotations = typing.get_type_hints(cls) field_names = builtins.tuple(field_annotations.keys()) diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py index dc8eeacf5f55..75823538e180 100644 --- a/python/triton/runtime/jit.py +++ b/python/triton/runtime/jit.py @@ -10,7 +10,7 @@ from collections import defaultdict from dataclasses import dataclass from functools import cached_property -from typing import Callable, Generic, Iterable, Optional, TypeVar, overload, Dict, Any, Tuple +from typing import Callable, Generic, Iterable, Optional, ParamSpec, TypeVar, overload, Dict, Any, Tuple, TYPE_CHECKING from triton.backends import BaseBackend from types import ModuleType @@ -27,6 +27,8 @@ INDENT_PATTERN = re.compile(r"^(?P[ \t]*)def\s+\w+\s*\(", re.MULTILINE) T = TypeVar("T") +P = ParamSpec("P") +R = TypeVar("R") # ----------------------------------------------------------------------------- # Dependencies Finder @@ -902,7 +904,7 @@ def finalize_compile(kernel): [attrs], warmup) return kernel - def __call__(self, *args, **kwargs): + def __call__(self: "JITFunction[Callable[P, R]]", *args: P.args, **kwargs: P.kwargs) -> R: raise RuntimeError("Cannot call @triton.jit'd outside of the scope of a kernel") def __repr__(self): @@ -1118,7 +1120,7 @@ def __call__(self, *args, **kwargs): return self.__func__(self.__self__, *args, **kwargs) -class ConstexprFunction(JITCallable): +class ConstexprFunction(JITCallable, Generic[T]): def __init__(self, fn): super().__init__(fn) @@ -1129,26 +1131,30 @@ def __get__(self, obj, objclass): return BoundConstexprFunction(obj, self) return self - def __call__(self, *args, _semantic=None, **kwargs): - from triton.language.core import _unwrap_if_constexpr, constexpr - # de-constexpr arguments and discard the _semantic keyword argument: - args = [_unwrap_if_constexpr(x) for x in args] - kwargs = {k: _unwrap_if_constexpr(v) for (k, v) in kwargs.items()} + if TYPE_CHECKING: + def __call__(self: "ConstexprFunction[Callable[P, R]]", *args: P.args, **kwargs: P.kwargs) -> R: + ... + else: + def __call__(self, *args, _semantic=None, **kwargs): + from triton.language.core import _unwrap_if_constexpr, constexpr + # de-constexpr arguments and discard the _semantic keyword argument: + args = [_unwrap_if_constexpr(x) for x in args] + kwargs = {k: _unwrap_if_constexpr(v) for (k, v) in kwargs.items()} - # call the raw Python function f: - res = self.fn(*args, **kwargs) + # call the raw Python function f: + res = self.fn(*args, **kwargs) - if _semantic is None: - # Not called by triton code generator, e.g. in host code, another constexpr function, or even an aggreate's __init__ function - return res + if _semantic is None: + # Not called by triton code generator, e.g. in host code, another constexpr function, or even an aggreate's __init__ function + return res - # convert result back to a Triton constexpr: - if knobs.runtime.interpret: - return res # No constexpr in interpreter - return constexpr(res) + # convert result back to a Triton constexpr: + if knobs.runtime.interpret: + return res # No constexpr in interpreter + return constexpr(res) -def constexpr_function(fn): +def constexpr_function(fn: T) -> ConstexprFunction[T]: """ Wraps an arbitrary Python function so that it can be called at compile-time on constexpr arguments in a Triton function and From 35cafc4d765d5e8fb969ee49a40ea45e66bef1b9 Mon Sep 17 00:00:00 2001 From: Mogball Date: Wed, 15 Apr 2026 16:57:03 -0700 Subject: [PATCH 2/3] fmt --- python/triton/language/core.py | 1 + python/triton/runtime/jit.py | 2 ++ 2 files changed, 3 insertions(+) diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 3b59a7ee2812..065744420421 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -1587,6 +1587,7 @@ def _wrap_init_args(x): if TYPE_CHECKING: from typing_extensions import dataclass_transform else: + def dataclass_transform(**kwargs): return lambda obj: obj diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py index 75823538e180..412ba649ed29 100644 --- a/python/triton/runtime/jit.py +++ b/python/triton/runtime/jit.py @@ -1132,9 +1132,11 @@ def __get__(self, obj, objclass): return self if TYPE_CHECKING: + def __call__(self: "ConstexprFunction[Callable[P, R]]", *args: P.args, **kwargs: P.kwargs) -> R: ... else: + def __call__(self, *args, _semantic=None, **kwargs): from triton.language.core import _unwrap_if_constexpr, constexpr # de-constexpr arguments and discard the _semantic keyword argument: From f6b475e7bca4fce38b01ed9dfa4c1d6f77d07893 Mon Sep 17 00:00:00 2001 From: Mogball Date: Thu, 16 Apr 2026 09:51:10 -0700 Subject: [PATCH 3/3] overload --- python/triton/runtime/jit.py | 38 +++++++++++++++++------------------- 1 file changed, 18 insertions(+), 20 deletions(-) diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py index 412ba649ed29..eaaffdfbd118 100644 --- a/python/triton/runtime/jit.py +++ b/python/triton/runtime/jit.py @@ -10,7 +10,7 @@ from collections import defaultdict from dataclasses import dataclass from functools import cached_property -from typing import Callable, Generic, Iterable, Optional, ParamSpec, TypeVar, overload, Dict, Any, Tuple, TYPE_CHECKING +from typing import Callable, Generic, Iterable, Optional, ParamSpec, TypeVar, overload, Dict, Any, Tuple from triton.backends import BaseBackend from types import ModuleType @@ -1131,29 +1131,27 @@ def __get__(self, obj, objclass): return BoundConstexprFunction(obj, self) return self - if TYPE_CHECKING: + @overload + def __call__(self: "ConstexprFunction[Callable[P, R]]", *args: P.args, **kwargs: P.kwargs) -> R: + ... - def __call__(self: "ConstexprFunction[Callable[P, R]]", *args: P.args, **kwargs: P.kwargs) -> R: - ... - else: - - def __call__(self, *args, _semantic=None, **kwargs): - from triton.language.core import _unwrap_if_constexpr, constexpr - # de-constexpr arguments and discard the _semantic keyword argument: - args = [_unwrap_if_constexpr(x) for x in args] - kwargs = {k: _unwrap_if_constexpr(v) for (k, v) in kwargs.items()} + def __call__(self, *args, _semantic=None, **kwargs): + from triton.language.core import _unwrap_if_constexpr, constexpr + # de-constexpr arguments and discard the _semantic keyword argument: + args = [_unwrap_if_constexpr(x) for x in args] + kwargs = {k: _unwrap_if_constexpr(v) for (k, v) in kwargs.items()} - # call the raw Python function f: - res = self.fn(*args, **kwargs) + # call the raw Python function f: + res = self.fn(*args, **kwargs) - if _semantic is None: - # Not called by triton code generator, e.g. in host code, another constexpr function, or even an aggreate's __init__ function - return res + if _semantic is None: + # Not called by triton code generator, e.g. in host code, another constexpr function, or even an aggreate's __init__ function + return res - # convert result back to a Triton constexpr: - if knobs.runtime.interpret: - return res # No constexpr in interpreter - return constexpr(res) + # convert result back to a Triton constexpr: + if knobs.runtime.interpret: + return res # No constexpr in interpreter + return constexpr(res) def constexpr_function(fn: T) -> ConstexprFunction[T]: