diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py index eaaffdfbd118..182f54faebc2 100644 --- a/python/triton/runtime/jit.py +++ b/python/triton/runtime/jit.py @@ -10,7 +10,7 @@ from collections import defaultdict from dataclasses import dataclass from functools import cached_property -from typing import Callable, Generic, Iterable, Optional, ParamSpec, TypeVar, overload, Dict, Any, Tuple +from typing import Callable, Concatenate, Generic, Iterable, Optional, ParamSpec, TYPE_CHECKING, TypeVar, overload, Dict, Any, Tuple from triton.backends import BaseBackend from types import ModuleType @@ -29,6 +29,7 @@ T = TypeVar("T") P = ParamSpec("P") R = TypeVar("R") +U = TypeVar("U") # ----------------------------------------------------------------------------- # Dependencies Finder @@ -907,6 +908,20 @@ def finalize_compile(kernel): def __call__(self: "JITFunction[Callable[P, R]]", *args: P.args, **kwargs: P.kwargs) -> R: raise RuntimeError("Cannot call @triton.jit'd outside of the scope of a kernel") + if TYPE_CHECKING: + + @overload + def __get__(self, instance: None, owner: Optional[type] = None) -> "JITFunction[T]": + ... + + @overload + def __get__(self: "JITFunction[Callable[Concatenate[U, P], R]]", instance: Any, + owner: Optional[type] = None) -> Callable[P, R]: + ... + + def __get__(self, instance, owner=None): + ... + def __repr__(self): return f"JITFunction({self.module}:{self.fn.__qualname__})"