|
| 1 | +# Owner(s): ["module: inductor"] |
| 2 | +import os |
| 3 | +import tempfile |
| 4 | +from typing import Any, Callable |
| 5 | + |
| 6 | +import torch |
| 7 | +from torch._dynamo.device_interface import get_interface_for_device |
| 8 | +from torch._inductor.runtime import triton_helpers |
| 9 | +from torch._inductor.runtime.static_cuda_launcher import StaticallyLaunchedCudaKernel |
| 10 | +from torch._inductor.runtime.triton_compat import tl, triton |
| 11 | +from torch._inductor.runtime.triton_helpers import libdevice |
| 12 | +from torch._inductor.test_case import TestCase |
| 13 | +from torch.testing._internal.common_utils import skipIfRocm |
| 14 | +from torch.testing._internal.triton_utils import requires_cuda |
| 15 | + |
| 16 | + |
| 17 | +@requires_cuda |
| 18 | +class TestStaticCudaLauncher(TestCase): |
| 19 | + def setUp(self): |
| 20 | + # Create a temporary file to store the cubin. |
| 21 | + # We set delete=False so that the file persists after closing. |
| 22 | + self.tmp_file = tempfile.NamedTemporaryFile(mode="wb") |
| 23 | + self.tmp_file.close() # Close now; we'll open it for writing later. |
| 24 | + super().setUp() |
| 25 | + |
| 26 | + def tearDown(self): |
| 27 | + super().tearDown() |
| 28 | + # Delete the temporary cubin file. |
| 29 | + try: |
| 30 | + os.remove(self.tmp_file.name) |
| 31 | + except FileNotFoundError: |
| 32 | + pass |
| 33 | + |
| 34 | + def _make_launcher( |
| 35 | + self, |
| 36 | + kernel: Callable, |
| 37 | + args: tuple[Any, ...], |
| 38 | + grid: tuple[Any, ...] = (1,), |
| 39 | + ) -> StaticallyLaunchedCudaKernel: |
| 40 | + """ |
| 41 | + Compiles a Triton kernel with the provided *args, |
| 42 | + writes its cubin to the temporary file, and returns the file path. |
| 43 | + """ |
| 44 | + fn = triton.jit(kernel) |
| 45 | + # Launch the kernel to trigger compilation. |
| 46 | + compiled_kernel = fn[grid](*args) |
| 47 | + result = StaticallyLaunchedCudaKernel(compiled_kernel) |
| 48 | + result.write_cubin_to_file(self.tmp_file.name) |
| 49 | + result.load_kernel() |
| 50 | + return result |
| 51 | + |
| 52 | + @skipIfRocm |
| 53 | + def test_basic(self): |
| 54 | + def simple_kernel(arg0, arg1): |
| 55 | + x = tl.load(arg0) |
| 56 | + y = arg1 |
| 57 | + tl.store(arg0, x + y) |
| 58 | + |
| 59 | + arg0 = torch.zeros(1, dtype=torch.int32, device="cuda") |
| 60 | + arg1 = 5 |
| 61 | + args = (arg0, arg1) |
| 62 | + |
| 63 | + launcher = self._make_launcher(simple_kernel, args, (1,)) |
| 64 | + self.assertEqual(arg0, torch.tensor([5], dtype=torch.int32, device="cuda")) |
| 65 | + self.assertEqual(launcher.arg_tys, "Oi") |
| 66 | + new_arg0 = torch.zeros(1, dtype=torch.int32, device="cuda") |
| 67 | + device_interface = get_interface_for_device("cuda") |
| 68 | + stream = device_interface.get_raw_stream(device_interface.current_device()) |
| 69 | + |
| 70 | + launcher.run((1,), stream, new_arg0, arg1) |
| 71 | + self.assertEqual(new_arg0, arg0) |
| 72 | + |
| 73 | + # I wish I could macro all int types this into a single unit test on a loop, but |
| 74 | + # 1. variables aren't allowed as type annotations in python |
| 75 | + # 2. triton relies on inspect.get_source to get the type annotations |
| 76 | + # so I can't even use exec() to generate the test cases. |
| 77 | + # So we'll just make a few kernels by hand |
| 78 | + @skipIfRocm |
| 79 | + def test_unsigned_integers(self): |
| 80 | + def unsigned_integers( |
| 81 | + arg0, arg1: tl.uint8, arg2: tl.uint16, arg3: tl.uint32, arg4: tl.uint64 |
| 82 | + ): |
| 83 | + x = tl.load(arg0) |
| 84 | + y = arg1 + arg2 + arg3 + arg4 |
| 85 | + tl.store(arg0, x + y) |
| 86 | + |
| 87 | + arg0 = torch.zeros(1, dtype=torch.uint64, device="cuda") |
| 88 | + # Using small numbers creates a Literal type which triton treats as a constant |
| 89 | + args = (arg0, 50, 50, 50, 50) |
| 90 | + |
| 91 | + launcher = self._make_launcher(unsigned_integers, args, (1,)) |
| 92 | + self.assertEqual(arg0, torch.tensor([200], dtype=torch.uint64, device="cuda")) |
| 93 | + self.assertEqual(launcher.arg_tys, "OBHIK") |
| 94 | + new_arg0 = torch.zeros(1, dtype=torch.uint64, device="cuda") |
| 95 | + device_interface = get_interface_for_device("cuda") |
| 96 | + stream = device_interface.get_raw_stream(device_interface.current_device()) |
| 97 | + launcher.run((1,), stream, new_arg0, 50, 50, 50, 50) |
| 98 | + self.assertEqual(new_arg0, arg0) |
| 99 | + |
| 100 | + @skipIfRocm |
| 101 | + def test_signed_integers(self): |
| 102 | + def signed_integers( |
| 103 | + arg0, arg1: tl.int8, arg2: tl.int16, arg3: tl.int32, arg4: tl.int64 |
| 104 | + ): |
| 105 | + x = tl.load(arg0) |
| 106 | + y = arg1 + arg2 + arg3 + arg4 |
| 107 | + tl.store(arg0, x + y) |
| 108 | + |
| 109 | + arg0 = torch.zeros(1, dtype=torch.int64, device="cuda") |
| 110 | + # Using small numbers creates a Literal type which triton treats as a constant |
| 111 | + args = (arg0, 50, 50, 50, 50) |
| 112 | + |
| 113 | + launcher = self._make_launcher(signed_integers, args, (1,)) |
| 114 | + self.assertEqual(arg0, torch.tensor([200], dtype=torch.int64, device="cuda")) |
| 115 | + self.assertEqual(launcher.arg_tys, "Obhil") |
| 116 | + new_arg0 = torch.zeros(1, dtype=torch.int64, device="cuda") |
| 117 | + device_interface = get_interface_for_device("cuda") |
| 118 | + stream = device_interface.get_raw_stream(device_interface.current_device()) |
| 119 | + launcher.run((1,), stream, new_arg0, 50, 50, 50, 50) |
| 120 | + self.assertEqual(new_arg0, arg0) |
| 121 | + |
| 122 | + # TODO: floats don't work properly, triton seems to think they're all tl.float32 |
| 123 | + # despite type annotations. |
| 124 | + # There's also not really a good way for me to make a float16 in python... |
| 125 | + @skipIfRocm |
| 126 | + def test_floats(self): |
| 127 | + def floats(arg0, arg1: tl.float16, arg2: tl.float32, arg3: tl.float64): |
| 128 | + x = tl.load(arg0) |
| 129 | + y = arg1 + arg2 + arg3 |
| 130 | + tl.store(arg0, x + y) |
| 131 | + |
| 132 | + arg0 = torch.zeros(1, dtype=torch.float64, device="cuda") |
| 133 | + |
| 134 | + args = (arg0, 1.0, 1.0, 1.0) |
| 135 | + |
| 136 | + launcher = self._make_launcher(floats, args, (1,)) |
| 137 | + # TODO: in Pytorch's pinned version of triton, arg3 is typed as regular float |
| 138 | + # but in triton 3.3.0, this is fixed and it's 0ffd. We'll need to update later. |
| 139 | + self.assertEqual(launcher.arg_tys, "Offf") |
| 140 | + self.assertEqual(arg0, torch.tensor([3.0], dtype=torch.float64, device="cuda")) |
| 141 | + new_arg0 = torch.zeros(1, dtype=torch.float64, device="cuda") |
| 142 | + device_interface = get_interface_for_device("cuda") |
| 143 | + stream = device_interface.get_raw_stream(device_interface.current_device()) |
| 144 | + launcher.run((1,), stream, new_arg0, 1.0, 1.0, 1.0) |
| 145 | + self.assertEqual(new_arg0, arg0) |
| 146 | + |
| 147 | + @skipIfRocm |
| 148 | + def test_basic_1arg(self): |
| 149 | + def simple_kernel_1_arg(arg0): |
| 150 | + x = tl.load(arg0) |
| 151 | + tl.store(arg0, x + 1) |
| 152 | + |
| 153 | + arg0 = torch.zeros(1, dtype=torch.int32, device="cuda") |
| 154 | + launcher = self._make_launcher(simple_kernel_1_arg, (arg0,), (1,)) |
| 155 | + self.assertEqual(arg0, torch.tensor([1], dtype=torch.int32, device="cuda")) |
| 156 | + self.assertEqual(launcher.arg_tys, "O") |
| 157 | + new_arg0 = torch.zeros(1, dtype=torch.int32, device="cuda") |
| 158 | + device_interface = get_interface_for_device("cuda") |
| 159 | + stream = device_interface.get_raw_stream(device_interface.current_device()) |
| 160 | + |
| 161 | + launcher.run( |
| 162 | + (1,), |
| 163 | + stream, |
| 164 | + new_arg0, |
| 165 | + ) |
| 166 | + self.assertEqual(new_arg0, arg0) |
| 167 | + |
| 168 | + @skipIfRocm |
| 169 | + def test_constexpr(self): |
| 170 | + # Constexprs are compiled directly into the cubin file, |
| 171 | + # so we never need to pass it to StaticCudaLauncher. |
| 172 | + |
| 173 | + @triton.jit |
| 174 | + def kernel_constexpr(arg0, CONSTANT: tl.constexpr): |
| 175 | + x = tl.load(arg0) |
| 176 | + tl.store(arg0, x + CONSTANT) |
| 177 | + |
| 178 | + # Can't use make_launcher because constexpr needs to be constant |
| 179 | + arg0 = torch.zeros(1, dtype=torch.int32, device="cuda") |
| 180 | + compiled_kernel = kernel_constexpr[(1,)](arg0, CONSTANT=5) |
| 181 | + launcher = StaticallyLaunchedCudaKernel(compiled_kernel) |
| 182 | + launcher.write_cubin_to_file(self.tmp_file.name) |
| 183 | + launcher.load_kernel() |
| 184 | + |
| 185 | + self.assertEqual(arg0, torch.tensor([5], dtype=torch.int32, device="cuda")) |
| 186 | + self.assertEqual(launcher.arg_tys, "O") |
| 187 | + new_arg0 = torch.zeros(1, dtype=torch.int32, device="cuda") |
| 188 | + device_interface = get_interface_for_device("cuda") |
| 189 | + stream = device_interface.get_raw_stream(device_interface.current_device()) |
| 190 | + launcher.run( |
| 191 | + (1,), |
| 192 | + stream, |
| 193 | + new_arg0, |
| 194 | + ) |
| 195 | + self.assertEqual(new_arg0, arg0) |
| 196 | + |
| 197 | + @skipIfRocm |
| 198 | + def test_implied_constant(self): |
| 199 | + """xnumel is unused in this kernel, but isn't explicitly marked as a constexpr""" |
| 200 | + |
| 201 | + # This kernel was generated by inductor so it has a bunch of unused arguments. We don't change it |
| 202 | + @triton.jit |
| 203 | + def triton_red_fused_any_isinf_0( |
| 204 | + in_ptr0, |
| 205 | + out_ptr0, |
| 206 | + xnumel, # noqa: F841 |
| 207 | + r0_numel, |
| 208 | + XBLOCK: tl.constexpr, |
| 209 | + R0_BLOCK: tl.constexpr, |
| 210 | + ): |
| 211 | + xnumel = 1 # noqa: F841 |
| 212 | + rnumel = r0_numel # noqa: F841 |
| 213 | + RBLOCK: tl.constexpr = R0_BLOCK # noqa: F841 |
| 214 | + xoffset = tl.program_id(0) * XBLOCK |
| 215 | + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] # noqa: F841 |
| 216 | + xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) # noqa: F841 |
| 217 | + r0_base = tl.arange(0, R0_BLOCK)[None, :] |
| 218 | + rbase = r0_base # noqa: F841 |
| 219 | + _tmp3 = tl.full([XBLOCK, R0_BLOCK], False, tl.int1) |
| 220 | + for r0_offset in range(0, r0_numel, R0_BLOCK): |
| 221 | + r0_index = r0_offset + r0_base |
| 222 | + r0_mask = r0_index < r0_numel |
| 223 | + roffset = r0_offset # noqa: F841 |
| 224 | + rindex = r0_index # noqa: F841 |
| 225 | + r0_0 = r0_index |
| 226 | + tmp0 = tl.load( |
| 227 | + in_ptr0 + (r0_0), r0_mask, eviction_policy="evict_first", other=0.0 |
| 228 | + ) |
| 229 | + tmp1 = libdevice.isinf(tmp0).to(tl.int1) |
| 230 | + tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) |
| 231 | + tmp4 = _tmp3 | tmp2 |
| 232 | + _tmp3 = tl.where(r0_mask, tmp4, _tmp3) |
| 233 | + tmp3 = triton_helpers.any(_tmp3.to(tl.int8), 1)[:, None].to(tl.int1) |
| 234 | + tl.store(out_ptr0 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp3, None) |
| 235 | + |
| 236 | + arg0 = torch.tensor([0.0, 0.5, float("inf"), 5], device="cuda") |
| 237 | + arg1 = torch.tensor([False], device="cuda") |
| 238 | + arg2 = torch.tensor([False], device="cuda") |
| 239 | + compiled_kernel = triton_red_fused_any_isinf_0[1,]( |
| 240 | + arg0, arg1, 1, 128, XBLOCK=1, R0_BLOCK=1 |
| 241 | + ) |
| 242 | + |
| 243 | + launcher = StaticallyLaunchedCudaKernel(compiled_kernel) |
| 244 | + launcher.write_cubin_to_file(self.tmp_file.name) |
| 245 | + launcher.load_kernel() |
| 246 | + |
| 247 | + device_interface = get_interface_for_device("cuda") |
| 248 | + stream = device_interface.get_raw_stream(device_interface.current_device()) |
| 249 | + launcher.run((1,), stream, arg0, arg2, 1, 128) |
| 250 | + self.assertEqual(arg1, arg2) |
| 251 | + |
| 252 | + @skipIfRocm |
| 253 | + def test_kernel_empty_tensor(self): |
| 254 | + # Triton kernel generated by torch.compile of the following: |
| 255 | + # @torch.compile() |
| 256 | + # def foo(x, y): |
| 257 | + # return torch.cat(((x * 4), y + 10)) |
| 258 | + |
| 259 | + # Running with example input: |
| 260 | + # torch._dynamo.decorators.mark_unbacked(t, 0) |
| 261 | + # x = torch.rand(0, device="cuda") |
| 262 | + # y = torch.rand(20, device="cuda") |
| 263 | + |
| 264 | + @triton.jit |
| 265 | + def triton_poi_fused_cat_0( |
| 266 | + in_ptr0, in_ptr1, out_ptr0, ks0, xnumel, XBLOCK: tl.constexpr |
| 267 | + ): |
| 268 | + xoffset = tl.program_id(0).to(tl.int64) * XBLOCK |
| 269 | + xindex = xoffset + tl.arange(0, XBLOCK)[:].to(tl.int64) |
| 270 | + xmask = xindex < xnumel |
| 271 | + x0 = xindex |
| 272 | + tmp0 = x0 |
| 273 | + tmp3 = ks0 |
| 274 | + tmp4 = tmp0 < tmp3 |
| 275 | + tmp5 = tl.load( |
| 276 | + in_ptr0 + (x0), xmask & tmp4, eviction_policy="evict_last", other=0.0 |
| 277 | + ) |
| 278 | + tmp6 = 4.0 |
| 279 | + tmp7 = tmp5 * tmp6 |
| 280 | + tmp8 = tl.full(tmp7.shape, 0.0, tmp7.dtype) |
| 281 | + tmp9 = tl.where(tmp4, tmp7, tmp8) |
| 282 | + tmp10 = tmp0 >= tmp3 |
| 283 | + tmp13 = tl.load( |
| 284 | + in_ptr1 + (x0 + ((-1) * ks0)), |
| 285 | + xmask & tmp10, |
| 286 | + eviction_policy="evict_last", |
| 287 | + other=0.0, |
| 288 | + ) |
| 289 | + tmp14 = 10.0 |
| 290 | + tmp15 = tmp13 + tmp14 |
| 291 | + tmp16 = tl.full(tmp15.shape, 0.0, tmp15.dtype) |
| 292 | + tmp17 = tl.where(tmp10, tmp15, tmp16) |
| 293 | + tmp18 = tl.where(tmp4, tmp9, tmp17) |
| 294 | + tl.store(out_ptr0 + (x0), tmp18, xmask) |
| 295 | + |
| 296 | + arg0 = 0 |
| 297 | + arg1 = torch.randn(0, device="cuda") |
| 298 | + arg2 = torch.randn(20, device="cuda") |
| 299 | + buf0 = torch.empty(20, device="cuda") |
| 300 | + buf1 = torch.empty(20, device="cuda") |
| 301 | + xnumel = 20 + arg0 |
| 302 | + compiled_kernel = triton_poi_fused_cat_0[(1,)]( |
| 303 | + arg1, arg2, buf0, arg0, xnumel, XBLOCK=32 |
| 304 | + ) |
| 305 | + launcher = StaticallyLaunchedCudaKernel(compiled_kernel) |
| 306 | + |
| 307 | + launcher.write_cubin_to_file(self.tmp_file.name) |
| 308 | + launcher.load_kernel() |
| 309 | + device_interface = get_interface_for_device("cuda") |
| 310 | + stream = device_interface.get_raw_stream(device_interface.current_device()) |
| 311 | + |
| 312 | + launcher.run((1, 1, 1), stream, arg1, arg2, buf1, arg0, xnumel) |
| 313 | + self.assertEqual(buf0, buf1) |
| 314 | + |
| 315 | + |
| 316 | +if __name__ == "__main__": |
| 317 | + from torch._inductor.test_case import run_tests |
| 318 | + |
| 319 | + run_tests() |
0 commit comments