From 628ff0bc4f0cf1f27ba5e03932c66378e6e34593 Mon Sep 17 00:00:00 2001 From: Mogball Date: Thu, 14 Nov 2024 22:27:19 -0800 Subject: [PATCH] [Tutorial] Remove incorrect caching from softmax tutorial The fused softmax implementation in the tutorial precompiles the kernel to query the register usage of the kernel, based on the parameters used to specialize the kernel. On top of this, it implements a simple caching system for this step based on just the block size. As noted in https://github.com/triton-lang/triton/issues/4739, this caching is incorrect, because it's also not keyed on the `num_stages` constexpr argument or the shapes of the tensors. Since triton already has its own JIT compilation cache, and this caching bit is not really relevant to the tutorial, just remove it to get rid of the footgun. --- python/tutorials/02-fused-softmax.py | 55 +++++++++++++--------------- 1 file changed, 26 insertions(+), 29 deletions(-) diff --git a/python/tutorials/02-fused-softmax.py b/python/tutorials/02-fused-softmax.py index d08afb1e59d2..c98042559770 100644 --- a/python/tutorials/02-fused-softmax.py +++ b/python/tutorials/02-fused-softmax.py @@ -139,35 +139,32 @@ def softmax(x): y = torch.empty_like(x) # pre-compile kernel to get register usage and compute thread occupancy. - kernel, num_programs = kernels.get(BLOCK_SIZE, (None, 0)) - if kernel is None: - kernel = softmax_kernel.warmup(y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE, - num_stages=num_stages, num_warps=num_warps, grid=(1, )) - kernel._init_handles() - n_regs = kernel.n_regs - size_smem = kernel.metadata.shared - if is_hip(): - # NUM_REGS represents the number of regular purpose registers. On CDNA architectures this is half of all registers available. - # However, this is not always the case. In most cases all registers can be used as regular purpose registers. - # ISA SECTION (3.6.4 for CDNA3) - # VGPRs are allocated out of two pools: regular VGPRs and accumulation VGPRs. Accumulation VGPRs are used - # with matrix VALU instructions, and can also be loaded directly from memory. A wave may have up to 512 total - # VGPRs, 256 of each type. When a wave has fewer than 512 total VGPRs, the number of each type is flexible - it is - # not required to be equal numbers of both types. - if is_cdna(): - NUM_GPRS = NUM_REGS * 2 - - # MAX_NUM_THREADS represents maximum number of resident threads per multi-processor. - # When we divide this number with WARP_SIZE we get maximum number of waves that can - # execute on a CU (multi-processor) in parallel. - MAX_NUM_THREADS = properties["max_threads_per_sm"] - max_num_waves = MAX_NUM_THREADS // WARP_SIZE - occupancy = min(NUM_GPRS // WARP_SIZE // n_regs, max_num_waves) // num_warps - else: - occupancy = NUM_REGS // (n_regs * WARP_SIZE * num_warps) - occupancy = min(occupancy, SIZE_SMEM // size_smem) - num_programs = NUM_SM * occupancy - kernels[BLOCK_SIZE] = (kernel, num_programs) + kernel = softmax_kernel.warmup(y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE, + num_stages=num_stages, num_warps=num_warps, grid=(1, )) + kernel._init_handles() + n_regs = kernel.n_regs + size_smem = kernel.metadata.shared + if is_hip(): + # NUM_REGS represents the number of regular purpose registers. On CDNA architectures this is half of all registers available. + # However, this is not always the case. In most cases all registers can be used as regular purpose registers. + # ISA SECTION (3.6.4 for CDNA3) + # VGPRs are allocated out of two pools: regular VGPRs and accumulation VGPRs. Accumulation VGPRs are used + # with matrix VALU instructions, and can also be loaded directly from memory. A wave may have up to 512 total + # VGPRs, 256 of each type. When a wave has fewer than 512 total VGPRs, the number of each type is flexible - it is + # not required to be equal numbers of both types. + if is_cdna(): + NUM_GPRS = NUM_REGS * 2 + + # MAX_NUM_THREADS represents maximum number of resident threads per multi-processor. + # When we divide this number with WARP_SIZE we get maximum number of waves that can + # execute on a CU (multi-processor) in parallel. + MAX_NUM_THREADS = properties["max_threads_per_sm"] + max_num_waves = MAX_NUM_THREADS // WARP_SIZE + occupancy = min(NUM_GPRS // WARP_SIZE // n_regs, max_num_waves) // num_warps + else: + occupancy = NUM_REGS // (n_regs * WARP_SIZE * num_warps) + occupancy = min(occupancy, SIZE_SMEM // size_smem) + num_programs = NUM_SM * occupancy num_programs = min(num_programs, n_rows)