Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -531,10 +531,10 @@ namespace TensileLite
double L_tile_total = (L_tile_single * num_iter) + L_prologue + L_epilogue + L_WG_setup
+ (28 * num_iter); //Iteration branch latency

if(MT_K == 512)
{
L_tile_total *= 1.5;
}
// if(MT_K == 512)
// {
// L_tile_total *= 1.5;
// }

if(debug || Hardware::is_debug_enabled())
{
Expand Down Expand Up @@ -799,13 +799,116 @@ namespace TensileLite
element_size_out,
mx_block_size,
debug);

//Enable Customized Heuristics.
bool enable_heuristics = true;

if(enable_heuristics)
{
//Check if tile size is bigger than problem dimension. Invalidate it if it's not MI dimension (smallest possible dimension)
double edge_waste_m = 1 - (static_cast<double>(M % MT_M) / MT_M);
double edge_waste_n = 1 - (static_cast<double>(N % MT_N) / MT_N);
if(((MT_M > M && MT_M != MI_M && edge_waste_m > 0.5)
|| (MT_N > N && MT_N != MI_N && edge_waste_n > 0.5))
|| (MT_K > K && MT_K != MI_K))
{
return std::numeric_limits<double>::max();
}

// //Set a minimum number of K iterations (avoid big K tile on Skinny K)
// if(safe_ceil_div(K, MT_K) < 4 && K > 128)
// {
// return std::numeric_limits<double>::max();
// }
}

// Compute latency for all waves and return it as the latency for the MT/problem
double total_latency = L_wave * N_waves;
//total_latency = total_latency + prologue_latency;
if(Hardware::is_debug_enabled())
{
hardware.print_debug_info();
}

if(enable_heuristics)
{
if(MT_M == 256 && MT_N == 256 && MT_K == 128 && (element_size_A == 8))
{
//The kernel for this is more optimized
total_latency = total_latency * 0.8;
}

if(MT_M == 256 && MT_N == 16 && MT_K == 128 && (element_size_A == 16))
{
//The kernel for this is less optimized, for some reason
total_latency = total_latency * 2;
}

if(MT_M == 16 && MT_N == 256 && MT_K == 128 && (element_size_A == 16))
{
//The kernel for this is less optimized, for some reason
total_latency = total_latency * 2;
}

//Bias Model towards at least one dim being power of 2
bool MT_M_is_power_two = (MT_M > 0) && ((MT_M & (MT_M - 1)) == 0);
bool MT_N_is_power_two = (MT_N > 0) && ((MT_N & (MT_N - 1)) == 0);
if(!MT_M_is_power_two && !MT_N_is_power_two)
{
total_latency = total_latency * 1.5;
}

//Bias Model towards both dims being a power of 2
if(MT_M_is_power_two && MT_N_is_power_two)
{
total_latency = total_latency * 0.9;
}

//Bias towards dimensions divisible by 64 for 8-bit datatypes
//This biases the other dimensions to being divisible by 64 bytes
if((MT_M > 64) && (MT_M % 64 != 0) && (element_size_A == 8))
{
total_latency = total_latency * 1.2;
}

//Bias towards dimensions divisible by 64 for 8-bit datatypes
//This biases the other dimensions to being divisible by 64 bytes
if((MT_N > 64) && (MT_N % 64 != 0) && (element_size_A == 8))
{
total_latency = total_latency * 1.2;
}

//Bias towards dimensions divisible by 32 for 16-bit datatypes
//This biases the other dimensions to being divisible by 64 bytes
if((MT_M > 32) && (MT_M % 32 != 0) && (element_size_A == 16))
{
total_latency = total_latency * 1.5;
}

//Bias towards dimensions divisible by 32 for 16-bit datatypes
//This biases the other dimensions to being divisible by 64 bytes
if((MT_N > 32) && (MT_N % 32 != 0) && (element_size_A == 16))
{
total_latency = total_latency * 1.5;
}

//Bias towards dimensions 32 or larger for bf16 datatypes
if(((MT_M >= 32) || MT_N >= 32) && (element_size_A == 16))
{
total_latency = total_latency * 0.9;
}
}
//If we can still fit one whole dimension in a singletile though, that's great! promote that
// if(MT_M >= M || MT_N >= N)
// {
// total_latency = total_latency * 0.8;
// }

// if(MT_M_is_power_two && MT_N_is_power_two)
// {
// total_latency = total_latency * 0.95;
// }

return total_latency;
}

Expand Down