diff --git a/python/tutorials/02-fused-softmax.py b/python/tutorials/02-fused-softmax.py index d08afb1e59d2..c98042559770 100644 --- a/python/tutorials/02-fused-softmax.py +++ b/python/tutorials/02-fused-softmax.py @@ -139,35 +139,32 @@ def softmax(x): y = torch.empty_like(x) # pre-compile kernel to get register usage and compute thread occupancy. - kernel, num_programs = kernels.get(BLOCK_SIZE, (None, 0)) - if kernel is None: - kernel = softmax_kernel.warmup(y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE, - num_stages=num_stages, num_warps=num_warps, grid=(1, )) - kernel._init_handles() - n_regs = kernel.n_regs - size_smem = kernel.metadata.shared - if is_hip(): - # NUM_REGS represents the number of regular purpose registers. On CDNA architectures this is half of all registers available. - # However, this is not always the case. In most cases all registers can be used as regular purpose registers. - # ISA SECTION (3.6.4 for CDNA3) - # VGPRs are allocated out of two pools: regular VGPRs and accumulation VGPRs. Accumulation VGPRs are used - # with matrix VALU instructions, and can also be loaded directly from memory. A wave may have up to 512 total - # VGPRs, 256 of each type. When a wave has fewer than 512 total VGPRs, the number of each type is flexible - it is - # not required to be equal numbers of both types. - if is_cdna(): - NUM_GPRS = NUM_REGS * 2 - - # MAX_NUM_THREADS represents maximum number of resident threads per multi-processor. - # When we divide this number with WARP_SIZE we get maximum number of waves that can - # execute on a CU (multi-processor) in parallel. - MAX_NUM_THREADS = properties["max_threads_per_sm"] - max_num_waves = MAX_NUM_THREADS // WARP_SIZE - occupancy = min(NUM_GPRS // WARP_SIZE // n_regs, max_num_waves) // num_warps - else: - occupancy = NUM_REGS // (n_regs * WARP_SIZE * num_warps) - occupancy = min(occupancy, SIZE_SMEM // size_smem) - num_programs = NUM_SM * occupancy - kernels[BLOCK_SIZE] = (kernel, num_programs) + kernel = softmax_kernel.warmup(y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE, + num_stages=num_stages, num_warps=num_warps, grid=(1, )) + kernel._init_handles() + n_regs = kernel.n_regs + size_smem = kernel.metadata.shared + if is_hip(): + # NUM_REGS represents the number of regular purpose registers. On CDNA architectures this is half of all registers available. + # However, this is not always the case. In most cases all registers can be used as regular purpose registers. + # ISA SECTION (3.6.4 for CDNA3) + # VGPRs are allocated out of two pools: regular VGPRs and accumulation VGPRs. Accumulation VGPRs are used + # with matrix VALU instructions, and can also be loaded directly from memory. A wave may have up to 512 total + # VGPRs, 256 of each type. When a wave has fewer than 512 total VGPRs, the number of each type is flexible - it is + # not required to be equal numbers of both types. + if is_cdna(): + NUM_GPRS = NUM_REGS * 2 + + # MAX_NUM_THREADS represents maximum number of resident threads per multi-processor. + # When we divide this number with WARP_SIZE we get maximum number of waves that can + # execute on a CU (multi-processor) in parallel. + MAX_NUM_THREADS = properties["max_threads_per_sm"] + max_num_waves = MAX_NUM_THREADS // WARP_SIZE + occupancy = min(NUM_GPRS // WARP_SIZE // n_regs, max_num_waves) // num_warps + else: + occupancy = NUM_REGS // (n_regs * WARP_SIZE * num_warps) + occupancy = min(occupancy, SIZE_SMEM // size_smem) + num_programs = NUM_SM * occupancy num_programs = min(num_programs, n_rows)