
Commit a9c5527

jamesjwu authored and pytorchmergebot committed
[Reland] First version of statically compiled launcher for triton compiled CUDA kernels (pytorch#149238)
This is a new version of pytorch#148561, fixing the ROCm test failure. Putting this up for a first-pass review, though I will likely make a bunch of changes before landing to add more features, etc.

This diff implements a first version of a static CUDA kernel launcher in `torch._C`. The goal here is to take a cubin file and some metadata from a `triton` CompiledKernel, and launch the cubin file directly. Background doc: https://docs.google.com/document/d/1rjRcHl6MfauHG30nCoQX-9UKvKyIs4WWMy_GsGyqb9g/edit?tab=t.0#heading=h.ut5lf39lzq66

Normally, using triton's `CompiledKernel.make_launcher()`, we would pay the cost of codegenning C++ and running it at compile time. With this new approach, we can use one statically compiled library to launch the kernel. The tradeoff is that this new kernel launcher cannot use codegen to deal with different lengths/types of arguments, so we use templating to handle up to 10 arguments for now. We also allocate 8 bytes on the stack per argument no matter the argument type, which can take more memory than codegenning. On the other hand, we improve compile time on cold and warm start by not having to call the C++ compiler at all.

This diff does not add the launcher to torch; it introduces a basic test suite. A list of TODOs that are not yet complete:

- Handle `nvTmaDesc` and `cuTensorMap`, which triton handles
- Embed the grid logic instead of passing in gridX, Y, Z
- Handle launch_enter and exit hooks? (Not sure if inductor has these)
- Benchmarking to see if there's runtime performance loss
- Probably lots of features of the triton C++ generated code that I haven't handled yet

Pull Request resolved: pytorch#149238
Approved by: https://github.com/oulgen
1 parent c83c711 · commit a9c5527
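The "8 bytes on the stack per argument" tradeoff is easiest to see in miniature. Below is a hedged Python sketch of the packing scheme the commit message describes (the real logic is C++ in `torch/csrc/inductor/static_cuda_launcher.cpp`; `pack_args` and the type-code table here are illustrative assumptions, not that code): every argument gets a fixed 8-byte slot regardless of its C type, and the launch path hands CUDA an array of pointers to those slots, so no per-kernel codegen is needed.

```python
import struct

SLOT_SIZE = 8  # one fixed-size slot per argument, regardless of its C type

# Illustrative mapping from per-argument type codes (see the test suite's
# arg_tys strings below) to struct formats; an assumption, not the C++ table.
STRUCT_FMT = {
    "b": "b", "B": "B",  # int8 / uint8
    "h": "h", "H": "H",  # int16 / uint16
    "i": "i", "I": "I",  # int32 / uint32
    "l": "q", "K": "Q",  # int64 / uint64
    "f": "f", "d": "d",  # float32 / float64
    "O": "Q",            # device pointer, passed as a 64-bit address
}


def pack_args(arg_tys: str, args: tuple) -> tuple[bytearray, list[int]]:
    """Pack each argument at the start of its own 8-byte slot and return the
    buffer plus the slot offsets (cuLaunchKernel would receive an array of
    pointers to these offsets)."""
    buf = bytearray(SLOT_SIZE * len(args))
    offsets = []
    for slot, (code, arg) in enumerate(zip(arg_tys, args)):
        offset = slot * SLOT_SIZE
        struct.pack_into(STRUCT_FMT[code], buf, offset, arg)
        offsets.append(offset)
    return buf, offsets


buf, offsets = pack_args("Oi", (0xDEADBEEF, 5))  # a pointer slot + an int32 slot
assert len(buf) == 16  # 2 arguments -> 2 slots -> 16 bytes, whatever the types
```

This also makes the memory tradeoff concrete: a one-byte `uint8` argument still occupies a full 8-byte slot, which a codegenned launcher would avoid.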

File tree

9 files changed: +1004 −0 lines changed


aten/src/ATen/cuda/detail/LazyNVRTC.cpp

+3

```diff
@@ -156,6 +156,7 @@ NVRTC_STUB2(nvrtcGetProgramLogSize, nvrtcProgram, size_t*)
 NVRTC_STUB2(nvrtcGetProgramLog, nvrtcProgram, char *)
 NVRTC_STUB3(nvrtcGetLoweredName, nvrtcProgram, const char *, const char **)
 
+CUDA_STUB2(cuModuleLoad, CUmodule*, const char*)
 CUDA_STUB2(cuModuleLoadData, CUmodule *, const void *)
 CUDA_STUB3(cuModuleGetFunction, CUfunction *, CUmodule, const char *)
 CUDA_STUB4(cuOccupancyMaxActiveBlocksPerMultiprocessor, int *, CUfunction, int, size_t)
@@ -169,6 +170,8 @@ CUDA_STUB4(cuLinkCreate, unsigned int, CUjit_option *, void **, CUlinkState *)
 CUDA_STUB3(cuLinkComplete, CUlinkState, void **, size_t *)
 CUDA_STUB3(cuFuncSetAttribute, CUfunction, CUfunction_attribute, int)
 CUDA_STUB3(cuFuncGetAttribute, int*, CUfunction_attribute, CUfunction)
+CUDA_STUB3(cuPointerGetAttribute, void*, CUpointer_attribute, CUdeviceptr)
+
 
 #if defined(CUDA_VERSION) && CUDA_VERSION >= 12000
 CUresult CUDAAPI
```

aten/src/ATen/cuda/nvrtc_stub/ATenNVRTC.h

+2

```diff
@@ -43,6 +43,7 @@ namespace at::cuda {
   _(nvrtcGetProgramLogSize)        \
   _(nvrtcGetProgramLog)            \
   _(nvrtcGetLoweredName)           \
+  _(cuModuleLoad)                  \
   _(cuModuleLoadData)              \
   _(cuModuleLoadDataEx)            \
   _(cuModuleGetFunction)           \
@@ -60,6 +61,7 @@ namespace at::cuda {
   _(cuLinkComplete)                \
   _(cuFuncSetAttribute)            \
   _(cuFuncGetAttribute)            \
+  _(cuPointerGetAttribute)         \
 
 #if defined(CUDA_VERSION) && CUDA_VERSION >= 12000
 #define AT_FORALL_NVRTC_EXTENDED(_) \
```
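Both stub files add lazy bindings for `cuModuleLoad` (load a module from a cubin file on disk) and `cuPointerGetAttribute` (resolve an arbitrary pointer into a canonical device pointer). As a rough illustration of what these driver calls do, here is a hedged ctypes sketch; it talks to `libcuda` directly and is not how PyTorch binds them (PyTorch goes through the lazy NVRTC stubs above). The cubin path and kernel name are placeholders.

```python
import ctypes

import torch

# Sketch only: call the CUDA driver API directly via ctypes. Assumes Linux
# with libcuda.so.1 and an existing CUDA context (allocating a CUDA tensor
# below initializes one).
libcuda = ctypes.CDLL("libcuda.so.1")

t = torch.zeros(1, device="cuda")  # ensures the driver/context is initialized

# cuPointerGetAttribute(void* data, CUpointer_attribute attr, CUdeviceptr ptr):
# attribute 2 is CU_POINTER_ATTRIBUTE_DEVICE_POINTER, which resolves the
# canonical device pointer behind t.data_ptr().
CU_POINTER_ATTRIBUTE_DEVICE_POINTER = 2
dev_ptr = ctypes.c_void_p()
res = libcuda.cuPointerGetAttribute(
    ctypes.byref(dev_ptr),
    ctypes.c_int(CU_POINTER_ATTRIBUTE_DEVICE_POINTER),
    ctypes.c_uint64(t.data_ptr()),
)
assert res == 0  # CUDA_SUCCESS

# cuModuleLoad(CUmodule*, const char* fname) loads a cubin from disk, and
# cuModuleGetFunction(CUfunction*, CUmodule, const char* name) fetches a
# kernel handle by name; placeholder path/name, so left commented out:
module = ctypes.c_void_p()
# res = libcuda.cuModuleLoad(ctypes.byref(module), b"/path/to/kernel.cubin")
# func = ctypes.c_void_p()
# res = libcuda.cuModuleGetFunction(ctypes.byref(func), module, b"kernel_name")
```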

build_variables.bzl

+1

```diff
@@ -859,6 +859,7 @@ libtorch_python_core_sources = [
     "torch/csrc/inductor/aoti_eager/kernel_holder.cpp",
     "torch/csrc/inductor/aoti_eager/kernel_meta_info.cpp",
     "torch/csrc/inductor/resize_storage_bytes.cpp",
+    "torch/csrc/inductor/static_cuda_launcher.cpp",
     "torch/csrc/jit/backends/backend_init.cpp",
     "torch/csrc/jit/python/init.cpp",
     "torch/csrc/jit/passes/onnx.cpp",
```
test/inductor/test_static_cuda_launcher.py

+319 (new file)

```python
# Owner(s): ["module: inductor"]
import os
import tempfile
from typing import Any, Callable

import torch
from torch._dynamo.device_interface import get_interface_for_device
from torch._inductor.runtime import triton_helpers
from torch._inductor.runtime.static_cuda_launcher import StaticallyLaunchedCudaKernel
from torch._inductor.runtime.triton_compat import tl, triton
from torch._inductor.runtime.triton_helpers import libdevice
from torch._inductor.test_case import TestCase
from torch.testing._internal.common_utils import skipIfRocm
from torch.testing._internal.triton_utils import requires_cuda


@requires_cuda
class TestStaticCudaLauncher(TestCase):
    def setUp(self):
        # Create a temporary file to store the cubin.
        # We set delete=False so that the file persists after closing.
        self.tmp_file = tempfile.NamedTemporaryFile(mode="wb", delete=False)
        self.tmp_file.close()  # Close now; we'll open it for writing later.
        super().setUp()

    def tearDown(self):
        super().tearDown()
        # Delete the temporary cubin file.
        try:
            os.remove(self.tmp_file.name)
        except FileNotFoundError:
            pass

    def _make_launcher(
        self,
        kernel: Callable,
        args: tuple[Any, ...],
        grid: tuple[Any, ...] = (1,),
    ) -> StaticallyLaunchedCudaKernel:
        """
        Compiles a Triton kernel with the provided *args, writes its cubin
        to the temporary file, and returns the statically launched kernel.
        """
        fn = triton.jit(kernel)
        # Launch the kernel to trigger compilation.
        compiled_kernel = fn[grid](*args)
        result = StaticallyLaunchedCudaKernel(compiled_kernel)
        result.write_cubin_to_file(self.tmp_file.name)
        result.load_kernel()
        return result

    @skipIfRocm
    def test_basic(self):
        def simple_kernel(arg0, arg1):
            x = tl.load(arg0)
            y = arg1
            tl.store(arg0, x + y)

        arg0 = torch.zeros(1, dtype=torch.int32, device="cuda")
        arg1 = 5
        args = (arg0, arg1)

        launcher = self._make_launcher(simple_kernel, args, (1,))
        self.assertEqual(arg0, torch.tensor([5], dtype=torch.int32, device="cuda"))
        self.assertEqual(launcher.arg_tys, "Oi")
        new_arg0 = torch.zeros(1, dtype=torch.int32, device="cuda")
        device_interface = get_interface_for_device("cuda")
        stream = device_interface.get_raw_stream(device_interface.current_device())

        launcher.run((1,), stream, new_arg0, arg1)
        self.assertEqual(new_arg0, arg0)
```
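`arg_tys` is the launcher's compact per-argument type signature: "Oi" above means one pointer-bearing object (the tensor) plus one 32-bit int. Judging by the assertions in this file, the codes resemble CPython's `PyArg_Parse`-style format characters; the table below is inferred from the test expectations ("Oi", "OBHIK", "Obhil", "Offf"), not lifted from the launcher's source.

```python
# Inferred meaning of the arg_tys codes used in this test file; a hedged
# reading of the test assertions, not the launcher's actual lookup table.
ARG_TY_MEANING = {
    "O": "tensor / device pointer (CUdeviceptr)",
    "b": "int8",    "B": "uint8",
    "h": "int16",   "H": "uint16",
    "i": "int32",   "I": "uint32",
    "l": "int64",   "K": "uint64",
    "f": "float32", "d": "float64",
}


def describe(arg_tys: str) -> list[str]:
    """Expand an arg_tys string into human-readable types."""
    return [ARG_TY_MEANING[c] for c in arg_tys]


print(describe("OBHIK"))  # test_unsigned_integers: pointer + uint8..uint64
```

The test file continues with the remaining integer, float, constexpr, and inductor-generated-kernel cases: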
```python
    # I wish I could macro all the int types into a single unit test on a loop, but
    # 1. variables aren't allowed as type annotations in python
    # 2. triton relies on inspect.get_source to get the type annotations
    # so I can't even use exec() to generate the test cases.
    # So we'll just make a few kernels by hand
    @skipIfRocm
    def test_unsigned_integers(self):
        def unsigned_integers(
            arg0, arg1: tl.uint8, arg2: tl.uint16, arg3: tl.uint32, arg4: tl.uint64
        ):
            x = tl.load(arg0)
            y = arg1 + arg2 + arg3 + arg4
            tl.store(arg0, x + y)

        arg0 = torch.zeros(1, dtype=torch.uint64, device="cuda")
        # Using small numbers creates a Literal type which triton treats as a constant
        args = (arg0, 50, 50, 50, 50)

        launcher = self._make_launcher(unsigned_integers, args, (1,))
        self.assertEqual(arg0, torch.tensor([200], dtype=torch.uint64, device="cuda"))
        self.assertEqual(launcher.arg_tys, "OBHIK")
        new_arg0 = torch.zeros(1, dtype=torch.uint64, device="cuda")
        device_interface = get_interface_for_device("cuda")
        stream = device_interface.get_raw_stream(device_interface.current_device())
        launcher.run((1,), stream, new_arg0, 50, 50, 50, 50)
        self.assertEqual(new_arg0, arg0)

    @skipIfRocm
    def test_signed_integers(self):
        def signed_integers(
            arg0, arg1: tl.int8, arg2: tl.int16, arg3: tl.int32, arg4: tl.int64
        ):
            x = tl.load(arg0)
            y = arg1 + arg2 + arg3 + arg4
            tl.store(arg0, x + y)

        arg0 = torch.zeros(1, dtype=torch.int64, device="cuda")
        # Using small numbers creates a Literal type which triton treats as a constant
        args = (arg0, 50, 50, 50, 50)

        launcher = self._make_launcher(signed_integers, args, (1,))
        self.assertEqual(arg0, torch.tensor([200], dtype=torch.int64, device="cuda"))
        self.assertEqual(launcher.arg_tys, "Obhil")
        new_arg0 = torch.zeros(1, dtype=torch.int64, device="cuda")
        device_interface = get_interface_for_device("cuda")
        stream = device_interface.get_raw_stream(device_interface.current_device())
        launcher.run((1,), stream, new_arg0, 50, 50, 50, 50)
        self.assertEqual(new_arg0, arg0)

    # TODO: floats don't work properly; triton seems to think they're all
    # tl.float32 despite the type annotations.
    # There's also not really a good way for me to make a float16 in python...
    @skipIfRocm
    def test_floats(self):
        def floats(arg0, arg1: tl.float16, arg2: tl.float32, arg3: tl.float64):
            x = tl.load(arg0)
            y = arg1 + arg2 + arg3
            tl.store(arg0, x + y)

        arg0 = torch.zeros(1, dtype=torch.float64, device="cuda")

        args = (arg0, 1.0, 1.0, 1.0)

        launcher = self._make_launcher(floats, args, (1,))
        # TODO: in PyTorch's pinned version of triton, arg3 is typed as a regular
        # float, but in triton 3.3.0 this is fixed and it's "Offd". We'll need to
        # update later.
        self.assertEqual(launcher.arg_tys, "Offf")
        self.assertEqual(arg0, torch.tensor([3.0], dtype=torch.float64, device="cuda"))
        new_arg0 = torch.zeros(1, dtype=torch.float64, device="cuda")
        device_interface = get_interface_for_device("cuda")
        stream = device_interface.get_raw_stream(device_interface.current_device())
        launcher.run((1,), stream, new_arg0, 1.0, 1.0, 1.0)
        self.assertEqual(new_arg0, arg0)

    @skipIfRocm
    def test_basic_1arg(self):
        def simple_kernel_1_arg(arg0):
            x = tl.load(arg0)
            tl.store(arg0, x + 1)

        arg0 = torch.zeros(1, dtype=torch.int32, device="cuda")
        launcher = self._make_launcher(simple_kernel_1_arg, (arg0,), (1,))
        self.assertEqual(arg0, torch.tensor([1], dtype=torch.int32, device="cuda"))
        self.assertEqual(launcher.arg_tys, "O")
        new_arg0 = torch.zeros(1, dtype=torch.int32, device="cuda")
        device_interface = get_interface_for_device("cuda")
        stream = device_interface.get_raw_stream(device_interface.current_device())

        launcher.run(
            (1,),
            stream,
            new_arg0,
        )
        self.assertEqual(new_arg0, arg0)

    @skipIfRocm
    def test_constexpr(self):
        # Constexprs are compiled directly into the cubin file,
        # so we never need to pass them to StaticCudaLauncher.

        @triton.jit
        def kernel_constexpr(arg0, CONSTANT: tl.constexpr):
            x = tl.load(arg0)
            tl.store(arg0, x + CONSTANT)

        # Can't use _make_launcher because the constexpr needs to be constant
        arg0 = torch.zeros(1, dtype=torch.int32, device="cuda")
        compiled_kernel = kernel_constexpr[(1,)](arg0, CONSTANT=5)
        launcher = StaticallyLaunchedCudaKernel(compiled_kernel)
        launcher.write_cubin_to_file(self.tmp_file.name)
        launcher.load_kernel()

        self.assertEqual(arg0, torch.tensor([5], dtype=torch.int32, device="cuda"))
        self.assertEqual(launcher.arg_tys, "O")
        new_arg0 = torch.zeros(1, dtype=torch.int32, device="cuda")
        device_interface = get_interface_for_device("cuda")
        stream = device_interface.get_raw_stream(device_interface.current_device())
        launcher.run(
            (1,),
            stream,
            new_arg0,
        )
        self.assertEqual(new_arg0, arg0)

    @skipIfRocm
    def test_implied_constant(self):
        """xnumel is unused in this kernel, but isn't explicitly marked as a constexpr"""

        # This kernel was generated by inductor, so it has a bunch of unused
        # arguments. We don't change it.
        @triton.jit
        def triton_red_fused_any_isinf_0(
            in_ptr0,
            out_ptr0,
            xnumel,  # noqa: F841
            r0_numel,
            XBLOCK: tl.constexpr,
            R0_BLOCK: tl.constexpr,
        ):
            xnumel = 1  # noqa: F841
            rnumel = r0_numel  # noqa: F841
            RBLOCK: tl.constexpr = R0_BLOCK  # noqa: F841
            xoffset = tl.program_id(0) * XBLOCK
            xindex = xoffset + tl.arange(0, XBLOCK)[:, None]  # noqa: F841
            xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)  # noqa: F841
            r0_base = tl.arange(0, R0_BLOCK)[None, :]
            rbase = r0_base  # noqa: F841
            _tmp3 = tl.full([XBLOCK, R0_BLOCK], False, tl.int1)
            for r0_offset in range(0, r0_numel, R0_BLOCK):
                r0_index = r0_offset + r0_base
                r0_mask = r0_index < r0_numel
                roffset = r0_offset  # noqa: F841
                rindex = r0_index  # noqa: F841
                r0_0 = r0_index
                tmp0 = tl.load(
                    in_ptr0 + (r0_0), r0_mask, eviction_policy="evict_first", other=0.0
                )
                tmp1 = libdevice.isinf(tmp0).to(tl.int1)
                tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK])
                tmp4 = _tmp3 | tmp2
                _tmp3 = tl.where(r0_mask, tmp4, _tmp3)
            tmp3 = triton_helpers.any(_tmp3.to(tl.int8), 1)[:, None].to(tl.int1)
            tl.store(out_ptr0 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp3, None)

        arg0 = torch.tensor([0.0, 0.5, float("inf"), 5], device="cuda")
        arg1 = torch.tensor([False], device="cuda")
        arg2 = torch.tensor([False], device="cuda")
        compiled_kernel = triton_red_fused_any_isinf_0[(1,)](
            arg0, arg1, 1, 128, XBLOCK=1, R0_BLOCK=1
        )

        launcher = StaticallyLaunchedCudaKernel(compiled_kernel)
        launcher.write_cubin_to_file(self.tmp_file.name)
        launcher.load_kernel()

        device_interface = get_interface_for_device("cuda")
        stream = device_interface.get_raw_stream(device_interface.current_device())
        launcher.run((1,), stream, arg0, arg2, 1, 128)
        self.assertEqual(arg1, arg2)

    @skipIfRocm
    def test_kernel_empty_tensor(self):
        # Triton kernel generated by torch.compile of the following:
        # @torch.compile()
        # def foo(x, y):
        #     return torch.cat(((x * 4), y + 10))

        # Running with example input:
        # torch._dynamo.decorators.mark_unbacked(t, 0)
        # x = torch.rand(0, device="cuda")
        # y = torch.rand(20, device="cuda")

        @triton.jit
        def triton_poi_fused_cat_0(
            in_ptr0, in_ptr1, out_ptr0, ks0, xnumel, XBLOCK: tl.constexpr
        ):
            xoffset = tl.program_id(0).to(tl.int64) * XBLOCK
            xindex = xoffset + tl.arange(0, XBLOCK)[:].to(tl.int64)
            xmask = xindex < xnumel
            x0 = xindex
            tmp0 = x0
            tmp3 = ks0
            tmp4 = tmp0 < tmp3
            tmp5 = tl.load(
                in_ptr0 + (x0), xmask & tmp4, eviction_policy="evict_last", other=0.0
            )
            tmp6 = 4.0
            tmp7 = tmp5 * tmp6
            tmp8 = tl.full(tmp7.shape, 0.0, tmp7.dtype)
            tmp9 = tl.where(tmp4, tmp7, tmp8)
            tmp10 = tmp0 >= tmp3
            tmp13 = tl.load(
                in_ptr1 + (x0 + ((-1) * ks0)),
                xmask & tmp10,
                eviction_policy="evict_last",
                other=0.0,
            )
            tmp14 = 10.0
            tmp15 = tmp13 + tmp14
            tmp16 = tl.full(tmp15.shape, 0.0, tmp15.dtype)
            tmp17 = tl.where(tmp10, tmp15, tmp16)
            tmp18 = tl.where(tmp4, tmp9, tmp17)
            tl.store(out_ptr0 + (x0), tmp18, xmask)

        arg0 = 0
        arg1 = torch.randn(0, device="cuda")
        arg2 = torch.randn(20, device="cuda")
        buf0 = torch.empty(20, device="cuda")
        buf1 = torch.empty(20, device="cuda")
        xnumel = 20 + arg0
        compiled_kernel = triton_poi_fused_cat_0[(1,)](
            arg1, arg2, buf0, arg0, xnumel, XBLOCK=32
        )
        launcher = StaticallyLaunchedCudaKernel(compiled_kernel)

        launcher.write_cubin_to_file(self.tmp_file.name)
        launcher.load_kernel()
        device_interface = get_interface_for_device("cuda")
        stream = device_interface.get_raw_stream(device_interface.current_device())

        launcher.run((1, 1, 1), stream, arg1, arg2, buf1, arg0, xnumel)
        self.assertEqual(buf0, buf1)


if __name__ == "__main__":
    from torch._inductor.test_case import run_tests

    run_tests()
```

torch/_C/__init__.pyi.in

+25

```diff
@@ -2545,3 +2545,28 @@ class _NodeIter(Iterator):
     def __init__(self, root: FxNode, reversed: _bool) -> None: ...
     def __iter__(self) -> Iterator[FxNode]: ...
     def __next__(self) -> FxNode: ...
+
+
+# Defined in torch/csrc/inductor/static_cuda_launcher.cpp
+class _StaticCudaLauncher:
+    @staticmethod
+    def _load_kernel(
+        cubin_file: str,
+        func_name: str,
+        shared_mem_bytes: _int,
+    ) -> Tuple[_int, _int, _int]:
+        ...
+
+    @staticmethod
+    def _launch_kernel(
+        func: _int,
+        grid_x: _int,
+        grid_y: _int,
+        grid_z: _int,
+        num_warps: _int,
+        shared_mem_bytes: _int,
+        arg_types: str,
+        args: Tuple[Any, ...],
+        stream: _int,
+    ) -> None:
+        ...
```