Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion python/triton/experimental/gluon/_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from triton.compiler.compiler import ASTSource
from triton.backends.compiler import Language
from triton.runtime.jit import JITFunction, constexpr_function
from typing import TypeVar, Optional, Callable, Iterable, Union
from typing import TypeVar, Optional, Callable, Iterable, Union, overload
from triton._C.libtriton import ir

T = TypeVar("T")
Expand Down Expand Up @@ -53,6 +53,25 @@ def is_gluon(self):
return True


@overload
def jit(fn: T) -> GluonJITFunction[T]:
...


@overload
def jit(
*,
version=None,
repr: Optional[Callable] = None,
launch_metadata: Optional[Callable] = None,
do_not_specialize: Optional[Iterable[int | str]] = None,
do_not_specialize_on_alignment: Optional[Iterable[int | str]] = None,
debug: Optional[bool] = None,
noinline: Optional[bool] = None,
) -> Callable[[T], GluonJITFunction[T]]:
...


def jit(
fn: Optional[T] = None,
*,
Expand Down
11 changes: 10 additions & 1 deletion python/triton/language/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from enum import Enum
from functools import partial, wraps, cached_property
import typing
from typing import Union, Callable, List, Sequence, TypeVar, Optional, Tuple
from typing import Union, Callable, List, Sequence, TypeVar, Optional, Tuple, TYPE_CHECKING
from dataclasses import dataclass
import builtins
from .. import knobs
Expand Down Expand Up @@ -1584,6 +1584,15 @@ def _wrap_init_args(x):
return constexpr(x)


if TYPE_CHECKING:
from typing_extensions import dataclass_transform
else:

def dataclass_transform(**kwargs):
return lambda obj: obj


@dataclass_transform(eq_default=False)
def _aggregate(cls):
field_annotations = typing.get_type_hints(cls)
field_names = builtins.tuple(field_annotations.keys())
Expand Down
14 changes: 10 additions & 4 deletions python/triton/runtime/jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from collections import defaultdict
from dataclasses import dataclass
from functools import cached_property
from typing import Callable, Generic, Iterable, Optional, TypeVar, overload, Dict, Any, Tuple
from typing import Callable, Generic, Iterable, Optional, ParamSpec, TypeVar, overload, Dict, Any, Tuple

from triton.backends import BaseBackend
from types import ModuleType
Expand All @@ -27,6 +27,8 @@
INDENT_PATTERN = re.compile(r"^(?P<indent>[ \t]*)def\s+\w+\s*\(", re.MULTILINE)

T = TypeVar("T")
P = ParamSpec("P")
R = TypeVar("R")

# -----------------------------------------------------------------------------
# Dependencies Finder
Expand Down Expand Up @@ -902,7 +904,7 @@ def finalize_compile(kernel):
[attrs], warmup)
return kernel

def __call__(self, *args, **kwargs):
def __call__(self: "JITFunction[Callable[P, R]]", *args: P.args, **kwargs: P.kwargs) -> R:
raise RuntimeError("Cannot call @triton.jit'd outside of the scope of a kernel")

def __repr__(self):
Expand Down Expand Up @@ -1118,7 +1120,7 @@ def __call__(self, *args, **kwargs):
return self.__func__(self.__self__, *args, **kwargs)


class ConstexprFunction(JITCallable):
class ConstexprFunction(JITCallable, Generic[T]):

def __init__(self, fn):
super().__init__(fn)
Expand All @@ -1129,6 +1131,10 @@ def __get__(self, obj, objclass):
return BoundConstexprFunction(obj, self)
return self

@overload
def __call__(self: "ConstexprFunction[Callable[P, R]]", *args: P.args, **kwargs: P.kwargs) -> R:
...

def __call__(self, *args, _semantic=None, **kwargs):
from triton.language.core import _unwrap_if_constexpr, constexpr
# de-constexpr arguments and discard the _semantic keyword argument:
Expand All @@ -1148,7 +1154,7 @@ def __call__(self, *args, _semantic=None, **kwargs):
return constexpr(res)


def constexpr_function(fn):
def constexpr_function(fn: T) -> ConstexprFunction[T]:
"""
Wraps an arbitrary Python function so that it can be called at
compile-time on constexpr arguments in a Triton function and
Expand Down
Loading