Skip to content

Commit 48d3ada

Browse files
authored
[TIR, TVMScript] Add TIR - Triton integration (#17395)
* [TIR, TVMScript] Add TIR - Triton integration Added a macro `T.call_triton` in TIR script parser, which expands to AOT compilation of the kernel and the host TIR code to launch the kernel.
1 parent 44808b4 commit 48d3ada

File tree

8 files changed

+477
-7
lines changed

8 files changed

+477
-7
lines changed

python/tvm/relax/vm_build.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -243,13 +243,25 @@ def _vmlink(
243243
if ext_libs is None:
244244
ext_libs = []
245245
lib = None
246+
relax_ext_libs = []
247+
tir_ext_libs = []
246248
if tir_mod is not None and len(tir_mod.get_global_vars()) > 0:
247249
lib = tvm.build(
248250
tir_mod,
249251
target=target,
250252
runtime=_autodetect_system_lib_req(target, system_lib),
251253
)
252-
return Executable(_ffi_api.VMLink(builder, target, lib, ext_libs, params)) # type: ignore
254+
for ext_mod in ext_libs:
255+
if ext_mod.type_key == "cuda":
256+
tir_ext_libs.append(ext_mod)
257+
else:
258+
relax_ext_libs.append(ext_mod)
259+
if lib is not None:
260+
for mod in tir_ext_libs:
261+
lib.import_module(mod)
262+
elif len(tir_ext_libs) > 0:
263+
print("Warning: No TIR module is found, but external modules for TIR are provided.")
264+
return Executable(_ffi_api.VMLink(builder, target, lib, relax_ext_libs, params)) # type: ignore
253265

254266

255267
def build(

python/tvm/script/ir_builder/ir/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
def_function,
2222
ir_module,
2323
module_attrs,
24+
module_get_attr,
25+
module_set_attr,
2426
module_global_infos,
2527
lookup_vdevice,
2628
vdevice,

python/tvm/script/ir_builder/ir/ir.py

Lines changed: 55 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
# under the License.
1717
"""Package tvm.script.ir_builder.ir.ir"""
1818

19-
from typing import Dict, List
19+
from typing import Dict, List, Optional
2020

2121
from tvm.ir import BaseFunc, GlobalVar, GlobalInfo, VDevice, DummyGlobalInfo
2222
from tvm.runtime import Object as tvm_Object
@@ -77,14 +77,66 @@ def def_function(func_name: str, func: BaseFunc) -> None:
7777
return _ffi_api.DefFunction(func_name, func) # type: ignore[attr-defined] # pylint: disable=no-member
7878

7979

80-
def module_attrs(attrs: Dict[str, tvm_Object], allow_overwrite: bool = False) -> None:
    """Specify the attrs of the ir_module frame.

    Parameters
    ----------
    attrs: Dict[str, Object]
        The module attrs.
    allow_overwrite: bool
        Whether to allow overwriting the existing attrs.
    """
    return _ffi_api.ModuleAttrs(attrs, allow_overwrite)  # type: ignore[attr-defined] # pylint: disable=no-member
90+
91+
92+
def current_ir_module() -> IRModuleFrame:
    """Return the ir_module frame that is currently active.

    Returns
    -------
    frame: IRModuleFrame
        The current frame, as returned by ``_ffi_api.CurrentIRModule``.
    """
    frame = _ffi_api.CurrentIRModule()  # type: ignore[attr-defined] # pylint: disable=no-member
    return frame
100+
101+
102+
def module_get_attrs() -> Dict[str, tvm_Object]:
    """Return all attrs of the ir_module frame.

    Returns
    -------
    attrs: Dict[str, Object]
        The module attrs, as returned by ``_ffi_api.ModuleGetAttrs``.
    """
    attrs = _ffi_api.ModuleGetAttrs()  # type: ignore[attr-defined] # pylint: disable=no-member
    return attrs
110+
111+
112+
def module_get_attr(attr_key: str) -> Optional[tvm_Object]:
    """Look up a single attr of the ir_module frame by key.

    Parameters
    ----------
    attr_key: str
        The key of the attr to be retrieved.

    Returns
    -------
    attr: Optional[Object]
        The specified module attr or None if not found.
    """
    attr = _ffi_api.ModuleGetAttr(attr_key)  # type: ignore[attr-defined] # pylint: disable=no-member
    return attr
124+
125+
126+
def module_set_attr(
    attr_key: str, attr_value: Optional[tvm_Object], allow_overwrite: bool = False
) -> None:
    """Store a single attr on the ir_module frame.

    Parameters
    ----------
    attr_key: str
        The key of the attr to be set.
    attr_value: Optional[Object]
        The value of the attr to be set.
    allow_overwrite: bool
        Whether to allow overwriting the existing attr.
    """
    result = _ffi_api.ModuleSetAttr(attr_key, attr_value, allow_overwrite)  # type: ignore[attr-defined] # pylint: disable=no-member
    return result
88140

89141

90142
def module_global_infos(global_infos: Dict[str, List[GlobalInfo]]) -> None:
Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""External kernel integration fro TIR"""
18+
import json
19+
import logging
20+
import tempfile
21+
from typing import Any, Dict, List, Tuple, Union
22+
23+
from tvm import __version__ as tvm_version
24+
from tvm import tir
25+
from tvm.runtime import Module, load_module
26+
27+
28+
class BaseKernel:
    """Base class for external kernels.

    Subclasses implement ``compile_to_device_module``; the helper methods here
    package compiled PTX together with the metadata file and load the pair as
    a runtime module.
    """

    def compile_to_device_module(
        self, launch_args, *args, **kwargs
    ) -> Tuple[str, Module, List[Any]]:
        """Compile the kernel to a device module.

        Parameters
        ----------
        launch_args
            The launch arguments of the kernel; their meaning is defined by the subclass.

        args
            The arguments to pass to the kernel.

        kwargs
            Additional keyword arguments for the kernel or its compilation.

        Returns
        -------
        result : Tuple[str, Module, List[Any]]
            The kernel name, the compiled device module, and the runtime arguments
            (this is how ``call_kernel`` unpacks the result).

        Raises
        ------
        NotImplementedError
            Always, in this base class.
        """
        raise NotImplementedError()

    def _format_tvm_module_metadata(self, kernel_name, arg_types, launch_param_tags):
        """Format the TVM module metadata.

        Parameters
        ----------
        kernel_name : str
            The name of the kernel, used as the key under ``func_info``.

        arg_types : List[str]
            The types of the kernel arguments.

        launch_param_tags : List[str]
            The tags of the launch parameters.

        Returns
        -------
        tvm_metadata : str
            The metadata serialized as a JSON document.
        """
        # Serialize the whole document with the JSON encoder rather than a string
        # template, so that the kernel name and version are escaped correctly.
        metadata = {
            "tvm_version": tvm_version,
            "func_info": {
                kernel_name: {
                    "name": "",
                    "arg_types": arg_types,
                    "launch_param_tags": launch_param_tags,
                }
            },
        }
        return json.dumps(metadata, indent=4)

    def _create_cuda_module(self, ptx, kernel_arg_types, launch_param_tags, kernel_name):
        """
        Create a CUDA module from PTX and metadata.

        Parameters
        ----------
        ptx : str
            The PTX code of the kernel.

        kernel_arg_types : List[str]
            The types of the kernel arguments.

        launch_param_tags : List[str]
            The tags of the launch parameters.

        kernel_name : str
            The name of the kernel.

        Returns
        -------
        kernel_module : Module
            The CUDA module.
        """
        tvm_metadata = self._format_tvm_module_metadata(
            kernel_name, kernel_arg_types, launch_param_tags
        )
        with tempfile.TemporaryDirectory() as temp_dir:
            ptx_path = f"{temp_dir}/{kernel_name}.ptx"
            # The metadata file sits next to the PTX file and shares its stem;
            # presumably `load_module` locates it through that naming convention.
            meta_path = f"{temp_dir}/{kernel_name}.tvm_meta.json"
            with open(ptx_path, "w", encoding="utf-8") as ptx_file:
                ptx_file.write(ptx)
            with open(meta_path, "w", encoding="utf-8") as meta_file:
                meta_file.write(tvm_metadata)
            # Load while the temporary directory still exists; it is removed on exit.
            kernel_module = load_module(ptx_path)
        return kernel_module
92+
93+
94+
def call_kernel(
    kernel,
    launch_args: List[Union[int, tir.PrimExpr, List[Union[int, tir.PrimExpr]]]],
    *args: Any,
    **kwargs: Any,
):
    """
    Call an external kernel.

    The kernel is compiled to a device module, the module is recorded in the
    "external_mods" attr of the current IRModule (unless a module implementing
    the same kernel name is already recorded there), and a packed call to the
    kernel is emitted.

    Parameters
    ----------
    kernel : Any
        The external kernel to call. Only ``triton.runtime.jit.JITFunction`` is supported.

    launch_args : List[Union[int, tir.PrimExpr, List[Union[int, tir.PrimExpr]]]]
        The launch arguments. A list of integers for grid size, block size, and shared memory size.
        The actual requirements depend on the kernel.

    args : List[tir.PrimExpr]
        The arguments to pass to the kernel.

    kwargs : Dict[str, Any]
        Additional keyword arguments to pass to the kernel or compilation.

    Returns
    -------
    call
        The result of ``call_packed`` on the kernel name and the runtime arguments.

    Raises
    ------
    ValueError
        If the type of ``kernel`` is not supported.
    """
    from ..ir import module_get_attr, module_set_attr  # pylint: disable=import-outside-toplevel
    from .ir import call_packed  # pylint: disable=import-outside-toplevel

    # The kernel type is compared by its qualified name, so triton is only
    # imported when a Triton kernel is actually passed in.
    kernel_type = f"{type(kernel).__module__}.{type(kernel).__qualname__}"
    if kernel_type != "triton.runtime.jit.JITFunction":
        raise ValueError(f"Unsupported kernel type {kernel_type}")

    from .triton import TritonKernel  # pylint: disable=import-outside-toplevel

    kernel = TritonKernel(kernel)

    kernel_name, kernel_module, runtime_args = kernel.compile_to_device_module(
        launch_args, *args, **kwargs
    )

    # Attach the kernel module to the current IRModule
    external_mods: List[Module] = module_get_attr("external_mods") or []
    kernel_exists = any(mod.implements_function(kernel_name) for mod in external_mods)
    if kernel_exists:
        logging.getLogger(__name__).debug(
            "Kernel %s already exists in the IRModule", kernel_name
        )
    else:
        external_mods.append(kernel_module)
        module_set_attr("external_mods", external_mods, True)
    return call_packed(kernel_name, *runtime_args)

python/tvm/script/ir_builder/tir/ir.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@
8383
from tvm.tir.generic import cast
8484

8585
from . import _ffi_api, frame
86+
from .external_kernel import call_kernel
8687

8788
# pylint: enable=unused-import
8889

@@ -1943,7 +1944,6 @@ def wrapped(*args, **kwargs):
19431944
tvm_call_packed_lowered = call_packed_lowered
19441945
tvm_call_cpacked_lowered = call_cpacked_lowered
19451946

1946-
19471947
# pylint: enable=invalid-name
19481948

19491949

@@ -2255,4 +2255,5 @@ def wrapped(*args, **kwargs):
22552255
"Range",
22562256
"vscale",
22572257
"get_active_lane_mask",
2258+
"call_kernel",
22582259
]
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""Triton kernel integration with TIR"""
18+
19+
from typing import Tuple, List, Union, Any, Dict
20+
21+
import triton
22+
from triton.runtime.jit import type_canonicalisation_dict
23+
from tvm import tir
24+
from tvm.topi.utils import get_const_int
25+
from tvm.runtime import Module
26+
from .external_kernel import BaseKernel
27+
28+
29+
class TritonKernel(BaseKernel):
    """A kernel from Triton JIT function.

    This class bridges the Triton kernel with TVM runtime. The compilation includes the following
    steps:
        - Deduce the kernel signature and generate the Triton kernel
        - Embed the compiled kernel into the current IRModule as an external module
        - Generate a call to the Triton kernel following its calling convention via call_packed.
    """

    def __init__(self, func):
        # The Triton JIT function to be compiled.
        self.func = func

    def compile_to_device_module(
        self,
        launch_args: List[Union[int, tir.PrimExpr]],
        *args: Any,
        **kwargs: Any,
    ) -> Tuple[str, Module, List[Any]]:
        """Compile the kernel to a device module.

        Parameters
        ----------
        launch_args : List[Union[int, tir.PrimExpr]]
            The grid size of the kernel. A list of one to three expressions, representing the
            number of "blockIdx.x", "blockIdx.y", and "blockIdx.z" respectively.

        args : List[Any]
            Arguments to the kernel function.

        kwargs : Dict[str, Any]
            Additional options for the kernel compilation.

        Returns
        -------
        result : Tuple[str, Module, List[Any]]
            The kernel name, the CUDA module holding the compiled kernel, and the runtime
            arguments: the kernel arguments followed by the launch arguments.

        Raises
        ------
        ValueError
            If more than three grid dimensions are given.
        """
        grid = list(launch_args)
        block_tags = ["blockIdx.x", "blockIdx.y", "blockIdx.z"]
        if len(grid) > len(block_tags):
            # Without this check the tags would be truncated to three while every grid
            # value is still passed on, leaving tags and launch values out of step.
            raise ValueError(
                f"At most {len(block_tags)} grid dimensions are supported, got {len(grid)}"
            )

        triton_kernel, kernel_args = self._generate_triton_kernel(self.func, *args, **kwargs)
        kernel_metadata = triton_kernel.metadata
        ptx = triton_kernel.asm["ptx"]
        assert kernel_metadata.num_ctas == 1, "Cluster is not supported"

        # One launch value per tag: the thread extent first, then the grid dimensions.
        # The thread extent is num_warps * 32 (32 is taken as the threads-per-warp count).
        launch_param_tags = ["threadIdx.x"] + block_tags[: len(grid)]
        launch_values = [kernel_metadata.num_warps * 32] + grid
        kernel_arg_types = [arg.dtype for arg in kernel_args]
        if kernel_metadata.shared > 0:
            # Add shared memory size to the launch arguments
            launch_param_tags.append("tir.use_dyn_shared_memory")
            launch_values.append(kernel_metadata.shared)

        kernel_module = self._create_cuda_module(
            ptx, kernel_arg_types, launch_param_tags, triton_kernel.name
        )

        return triton_kernel.name, kernel_module, kernel_args + launch_values

    def _generate_triton_kernel(
        self, func, *args, **kwargs
    ) -> Tuple["triton.compiler.CompiledKernel", List[tir.PrimExpr]]:
        """Deduce the kernel signature and generate the Triton kernel

        Parameters
        ----------
        func
            The Triton JIT function; its ``params`` are matched positionally against ``args``.

        args
            One argument per kernel parameter. Arguments bound to constexpr parameters are
            folded into compile-time constants and are not passed at runtime.

        kwargs
            Forwarded to ``triton.compiler.compile`` as ``options``.

        Returns
        -------
        result : Tuple[triton.compiler.CompiledKernel, List[tir.PrimExpr]]
            The compiled kernel and the arguments to invoke it with.
        """

        kernel_params = func.params
        assert len(kernel_params) == len(
            args
        ), f"Number of arguments does not match, expected {len(kernel_params)}, got {len(args)}"

        signature = {}
        constants = {}
        kernel_args = []  # Arguments to invoke the kernel
        for param, arg in zip(kernel_params, args):
            if param.is_constexpr:
                constants[param.name] = get_const_int(arg)
                continue
            if arg.dtype == "handle":
                # A handle is passed as a pointer to the element type of its annotation.
                assert isinstance(arg, tir.Var)
                elem_type = arg.type_annotation.element_type.dtype
                signature[param.name] = "*" + type_canonicalisation_dict[elem_type]
            else:
                signature[param.name] = type_canonicalisation_dict[arg.dtype]
            kernel_args.append(arg)

        # TODO: Support default argument in the kernel
        # TODO: Add specialization for aligned buffer pointers
        source = triton.compiler.ASTSource(fn=func, constants=constants, signature=signature)
        compiled = triton.compiler.compile(source, options=kwargs)
        return compiled, kernel_args

0 commit comments

Comments
 (0)