diff --git a/python/triton/experimental/gluon/_runtime.py b/python/triton/experimental/gluon/_runtime.py index 5b55d6fa526c..2f303b5f89a3 100644 --- a/python/triton/experimental/gluon/_runtime.py +++ b/python/triton/experimental/gluon/_runtime.py @@ -2,7 +2,7 @@ from triton.compiler.compiler import ASTSource from triton.backends.compiler import Language from triton.runtime.jit import JITFunction, constexpr_function -from typing import TypeVar, Optional, Callable, Iterable, Union +from typing import TypeVar, Optional, Callable, Iterable, Union, overload from triton._C.libtriton import ir T = TypeVar("T") @@ -53,6 +53,25 @@ def is_gluon(self): return True +@overload +def jit(fn: T) -> GluonJITFunction[T]: + ... + + +@overload +def jit( + *, + version=None, + repr: Optional[Callable] = None, + launch_metadata: Optional[Callable] = None, + do_not_specialize: Optional[Iterable[int | str]] = None, + do_not_specialize_on_alignment: Optional[Iterable[int | str]] = None, + debug: Optional[bool] = None, + noinline: Optional[bool] = None, +) -> Callable[[T], GluonJITFunction[T]]: + ... + + def jit( fn: Optional[T] = None, *, diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 61375297464e..065744420421 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -6,7 +6,7 @@ from enum import Enum from functools import partial, wraps, cached_property import typing -from typing import Union, Callable, List, Sequence, TypeVar, Optional, Tuple +from typing import Union, Callable, List, Sequence, TypeVar, Optional, Tuple, TYPE_CHECKING from dataclasses import dataclass import builtins from .. import knobs @@ -1584,6 +1584,15 @@ def _wrap_init_args(x): return constexpr(x) +if TYPE_CHECKING: + from typing_extensions import dataclass_transform +else: + + def dataclass_transform(**kwargs): + return lambda obj: obj + + +@dataclass_transform(eq_default=False) def _aggregate(cls): field_annotations = typing.get_type_hints(cls) field_names = builtins.tuple(field_annotations.keys()) diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py index dc8eeacf5f55..eaaffdfbd118 100644 --- a/python/triton/runtime/jit.py +++ b/python/triton/runtime/jit.py @@ -10,7 +10,7 @@ from collections import defaultdict from dataclasses import dataclass from functools import cached_property -from typing import Callable, Generic, Iterable, Optional, TypeVar, overload, Dict, Any, Tuple +from typing import Callable, Generic, Iterable, Optional, ParamSpec, TypeVar, overload, Dict, Any, Tuple from triton.backends import BaseBackend from types import ModuleType @@ -27,6 +27,8 @@ INDENT_PATTERN = re.compile(r"^(?P[ \t]*)def\s+\w+\s*\(", re.MULTILINE) T = TypeVar("T") +P = ParamSpec("P") +R = TypeVar("R") # ----------------------------------------------------------------------------- # Dependencies Finder @@ -902,7 +904,7 @@ def finalize_compile(kernel): [attrs], warmup) return kernel - def __call__(self, *args, **kwargs): + def __call__(self: "JITFunction[Callable[P, R]]", *args: P.args, **kwargs: P.kwargs) -> R: raise RuntimeError("Cannot call @triton.jit'd outside of the scope of a kernel") def __repr__(self): @@ -1118,7 +1120,7 @@ def __call__(self, *args, **kwargs): return self.__func__(self.__self__, *args, **kwargs) -class ConstexprFunction(JITCallable): +class ConstexprFunction(JITCallable, Generic[T]): def __init__(self, fn): super().__init__(fn) @@ -1129,6 +1131,10 @@ def __get__(self, obj, objclass): return BoundConstexprFunction(obj, self) return self + @overload + def __call__(self: "ConstexprFunction[Callable[P, R]]", *args: P.args, **kwargs: P.kwargs) -> R: + ... + def __call__(self, *args, _semantic=None, **kwargs): from triton.language.core import _unwrap_if_constexpr, constexpr # de-constexpr arguments and discard the _semantic keyword argument: @@ -1148,7 +1154,7 @@ def __call__(self, *args, _semantic=None, **kwargs): return constexpr(res) -def constexpr_function(fn): +def constexpr_function(fn: T) -> ConstexprFunction[T]: """ Wraps an arbitrary Python function so that it can be called at compile-time on constexpr arguments in a Triton function and