From bb896d2326583726f950e1963704fff2529a65bb Mon Sep 17 00:00:00 2001 From: Alex Brown Date: Wed, 2 Jul 2025 21:38:47 -0500 Subject: [PATCH] Ryan's Origami changes to enable heuristics that account for custom kernels --- .../lib/source/analytical/AnalyticalGemm.cpp | 111 +++++++++++++++++- 1 file changed, 107 insertions(+), 4 deletions(-) diff --git a/projects/hipblaslt/tensilelite/Tensile/Source/lib/source/analytical/AnalyticalGemm.cpp b/projects/hipblaslt/tensilelite/Tensile/Source/lib/source/analytical/AnalyticalGemm.cpp index 787ef99165b..0b179cf1a83 100644 --- a/projects/hipblaslt/tensilelite/Tensile/Source/lib/source/analytical/AnalyticalGemm.cpp +++ b/projects/hipblaslt/tensilelite/Tensile/Source/lib/source/analytical/AnalyticalGemm.cpp @@ -531,10 +531,10 @@ namespace TensileLite double L_tile_total = (L_tile_single * num_iter) + L_prologue + L_epilogue + L_WG_setup + (28 * num_iter); //Iteration branch latency - if(MT_K == 512) - { - L_tile_total *= 1.5; - } + // if(MT_K == 512) + // { + // L_tile_total *= 1.5; + // } if(debug || Hardware::is_debug_enabled()) { @@ -799,6 +799,29 @@ namespace TensileLite element_size_out, mx_block_size, debug); + + //Enable Customized Heuristics. + bool enable_heuristics = true; + + if(enable_heuristics) + { + //Check if tile size is bigger than problem dimension. Invalidate it if it's not MI dimension (smallest possible dimension) + double edge_waste_m = 1 - (static_cast(M % MT_M) / MT_M); + double edge_waste_n = 1 - (static_cast(N % MT_N) / MT_N); + if(((MT_M > M && MT_M != MI_M && edge_waste_m > 0.5) + || (MT_N > N && MT_N != MI_N && edge_waste_n > 0.5)) + || (MT_K > K && MT_K != MI_K)) + { + return std::numeric_limits::max(); + } + + // //Set a minimum number of K iterations (avoid big K tile on Skinny K) + // if(safe_ceil_div(K, MT_K) < 4 && K > 128) + // { + // return std::numeric_limits::max(); + // } + } + // Compute latency for all waves and return it as the latency for the MT/problem double total_latency = L_wave * N_waves; //total_latency = total_latency + prologue_latency; @@ -806,6 +829,86 @@ namespace TensileLite { hardware.print_debug_info(); } + + if(enable_heuristics) + { + if(MT_M == 256 && MT_N == 256 && MT_K == 128 && (element_size_A == 8)) + { + //The kernel for this is more optimized + total_latency = total_latency * 0.8; + } + + if(MT_M == 256 && MT_N == 16 && MT_K == 128 && (element_size_A == 16)) + { + //The kernel for this is less optimized, for some reason + total_latency = total_latency * 2; + } + + if(MT_M == 16 && MT_N == 256 && MT_K == 128 && (element_size_A == 16)) + { + //The kernel for this is less optimized, for some reason + total_latency = total_latency * 2; + } + + //Bias Model towards at least one dim being power of 2 + bool MT_M_is_power_two = (MT_M > 0) && ((MT_M & (MT_M - 1)) == 0); + bool MT_N_is_power_two = (MT_N > 0) && ((MT_N & (MT_N - 1)) == 0); + if(!MT_M_is_power_two && !MT_N_is_power_two) + { + total_latency = total_latency * 1.5; + } + + //Bias Model towards both dims being a power of 2 + if(MT_M_is_power_two && MT_N_is_power_two) + { + total_latency = total_latency * 0.9; + } + + //Bias towards dimensions divisible by 64 for 8-bit datatypes + //This biases the other dimensions to being divisible by 64 bytes + if((MT_M > 64) && (MT_M % 64 != 0) && (element_size_A == 8)) + { + total_latency = total_latency * 1.2; + } + + //Bias towards dimensions divisible by 64 for 8-bit datatypes + //This biases the other dimensions to being divisible by 64 bytes + if((MT_N > 64) && (MT_N % 64 != 0) && (element_size_A == 8)) + { + total_latency = total_latency * 1.2; + } + + //Bias towards dimensions divisible by 32 for 16-bit datatypes + //This biases the other dimensions to being divisible by 64 bytes + if((MT_M > 32) && (MT_M % 32 != 0) && (element_size_A == 16)) + { + total_latency = total_latency * 1.5; + } + + //Bias towards dimensions divisible by 32 for 16-bit datatypes + //This biases the other dimensions to being divisible by 64 bytes + if((MT_N > 32) && (MT_N % 32 != 0) && (element_size_A == 16)) + { + total_latency = total_latency * 1.5; + } + + //Bias towards dimensions 32 or larger for bf16 datatypes + if(((MT_M >= 32) || MT_N >= 32) && (element_size_A == 16)) + { + total_latency = total_latency * 0.9; + } + } + //If we can still fit one whole dimension in a singletile though, that's great! promote that + // if(MT_M >= M || MT_N >= N) + // { + // total_latency = total_latency * 0.8; + // } + + // if(MT_M_is_power_two && MT_N_is_power_two) + // { + // total_latency = total_latency * 0.95; + // } + return total_latency; }