From 9710f1ee0d7bf96707a445bc1c8716596c418f25 Mon Sep 17 00:00:00 2001 From: Pieter Ghysels Date: Tue, 9 Dec 2025 14:47:40 -0600 Subject: [PATCH 1/7] Few minor Origami changes --- shared/origami/src/origami/gemm.cpp | 72 ++++++++++++++++------------- 1 file changed, 41 insertions(+), 31 deletions(-) diff --git a/shared/origami/src/origami/gemm.cpp b/shared/origami/src/origami/gemm.cpp index 2404ced4254..1b952fdc804 100644 --- a/shared/origami/src/origami/gemm.cpp +++ b/shared/origami/src/origami/gemm.cpp @@ -363,7 +363,7 @@ double estimate_l2_hit(const problem_t& problem, // Use size_t for dimensions and counts to ensure type safety. const size_t workgroups_m = math::safe_ceil_div(problem.size.m, config.mt.m); const size_t workgroups_n = math::safe_ceil_div(problem.size.n, config.mt.n); - const size_t total_workgroups = workgroups_m * workgroups_n; + const size_t total_workgroups = workgroups_m * workgroups_n * problem.batch; // Concurrently executing workgroups are limited by the number of CUs.a const size_t concurrent_workgroups = std::min(total_workgroups, hardware.N_CU); @@ -377,7 +377,7 @@ double estimate_l2_hit(const problem_t& problem, std::max(math::safe_ceil_div(effective_cus, hardware.NUM_XCD), static_cast(1)); // Initial guess for the L2 tile dimensions (a tile of workgroups). - size_t l2_tile_n = std::min(static_cast(config.workgroup_mapping), workgroups_n); + size_t l2_tile_n = std::min(cu_per_xcd, std::min(static_cast(config.workgroup_mapping), workgroups_n)); size_t l2_tile_m = math::safe_ceil_div(cu_per_xcd, l2_tile_n); // Handle wrap-around case: if the tile is taller than the grid, wrap it to be wider. @@ -391,18 +391,22 @@ double estimate_l2_hit(const problem_t& problem, l2_tile_m = std::max(std::min(workgroups_m, l2_tile_m), static_cast(1)); l2_tile_n = std::max(std::min(workgroups_n, l2_tile_n), static_cast(1)); + double batches_per_xcd = std::min(static_cast(problem.batch), + std::max(static_cast(hardware.CU_per_L2) / (cu_per_xcd * splitting_factor), 1.)); + // Calculate memory footprint in bytes. const auto a_bytes = data_type_to_bytes(problem.a_dtype); const auto b_bytes = data_type_to_bytes(problem.b_dtype); auto calculate_footprint = [&](auto tile_m, auto tile_n) { auto a_footprint = tile_m * config.mt.mk() * a_bytes; auto b_footprint = tile_n * config.mt.nk() * b_bytes; - return a_footprint + b_footprint; + return (a_footprint + b_footprint) * batches_per_xcd; }; // Symmetrically shrink the L2 tile until it fits in the L2 cache capacity. // This is more robust than shrinking only one dimension. - while (calculate_footprint(l2_tile_m, l2_tile_n) > hardware.L2_capacity) { + while (calculate_footprint(l2_tile_m, l2_tile_n) > hardware.L2_capacity + || l2_tile_m * l2_tile_n > cu_per_xcd) { if (l2_tile_m > 1 && l2_tile_m >= l2_tile_n) { l2_tile_m--; } else if (l2_tile_n > 1) { @@ -447,9 +451,8 @@ double estimate_mall_hit(const problem_t& problem, // --- Initial Tile Sizing based on Concurrency --- // Use ceiling division for a more accurate initial guess. - size_t mall_tile_m = - math::safe_ceil_div(num_active_cus, static_cast(config.workgroup_mapping)); size_t mall_tile_n = std::min(static_cast(config.workgroup_mapping), workgroups_n); + size_t mall_tile_m = math::safe_ceil_div(num_active_cus, mall_tile_n); // Handle wrap-around case if the tile is taller than the grid. if (mall_tile_m > workgroups_m) { @@ -497,7 +500,8 @@ double estimate_mall_hit(const problem_t& problem, double compute_l2_hit_rate_global(const problem_t& problem, const hardware_t& hardware, const config_t& config, - size_t l2_capacity_bytes) { + size_t l2_capacity_bytes, + size_t splitting_factor) { // --- Hardware Parameters (as requested, defined locally) --- // You would normally get l2_capacity_bytes from your hardware_t struct. if (l2_capacity_bytes == 0) throw std::runtime_error("L2 Capacity is zero"); @@ -516,7 +520,8 @@ double compute_l2_hit_rate_global(const problem_t& problem, const double a_working_set = static_cast(grid_m * config.mt.mk()) * a_bytes; const double b_working_set = static_cast(grid_n * config.mt.nk()) * b_bytes; - const double total_working_set_bytes = a_working_set + b_working_set; + const double concurrent_batches = std::min(problem.batch, std::max(hardware.N_CU / (splitting_factor * grid_m * grid_n), static_cast(1))); + const double total_working_set_bytes = (a_working_set + b_working_set) * concurrent_batches; // 3. CRUCIAL: Check if the working set fits in the L2 cache. // If it doesn't, the global reuse pattern is broken by capacity misses, @@ -581,8 +586,9 @@ double compute_memory_latency(const problem_t& problem, // Global cap on L2 hit-rate (prevents impossible cache residency claims) // (Assumes capacity is given in KiB, convert to bytes) + // TODO hardware.L2_capacity is already in bytes? double H_mem1_global = - compute_l2_hit_rate_global(problem, hardware, config, hardware.L2_capacity * 1024); + compute_l2_hit_rate_global(problem, hardware, config, hardware.L2_capacity * 1024, splitting_factor); H_mem1 = std::min(H_mem1, H_mem1_global); @@ -592,22 +598,10 @@ double compute_memory_latency(const problem_t& problem, double H_mem2 = estimate_mall_hit(problem, hardware, config, num_active_cus, splitting_factor); // 3) Total loads are loads from A and loads from B - size_t MT_M_rounded_128bytes = round_elements_to_128B(MT_M, a_bits); - size_t MT_N_rounded_128bytes = round_elements_to_128B(MT_N, a_bits); - size_t MT_K_rounded_128bytes = round_elements_to_128B(MT_K, a_bits); - - if (!a_trans && !b_trans) { - MT_N_rounded_128bytes = MT_N; - MT_K_rounded_128bytes = MT_K; - } else if (a_trans && !b_trans) { - MT_M_rounded_128bytes = MT_M; - MT_N_rounded_128bytes = MT_N; - } else if (!a_trans && b_trans) { - MT_K_rounded_128bytes = MT_K; - } - - size_t Ld_A_value = MT_M_rounded_128bytes * MT_K_rounded_128bytes; - size_t Ld_B_value = MT_N_rounded_128bytes * MT_K_rounded_128bytes; + size_t Ld_A_value = a_trans ? + MT_M * round_elements_to_128B(MT_K, a_bits) : round_elements_to_128B(MT_M, a_bits) * MT_K; + size_t Ld_B_value = b_trans ? + round_elements_to_128B(MT_N, b_bits) * MT_K : MT_N * round_elements_to_128B(MT_K, b_bits); auto Ld_CU_bytes = (Ld_A_value * a_bytes) // A Bytes + (Ld_B_value * b_bytes); // B Bytes @@ -646,9 +640,8 @@ double compute_memory_latency(const problem_t& problem, // Calculate the tile of workgroups that can run concurrently (logic from estimate_mall_hit). size_t grid_m = math::safe_ceil_div(problem.size.m, MT_M); size_t grid_n = math::safe_ceil_div(problem.size.n, MT_N); - size_t mall_m = - math::safe_ceil_div(num_active_cus, static_cast(config.workgroup_mapping)); size_t mall_n = std::min(static_cast(config.workgroup_mapping), grid_n); + size_t mall_m = math::safe_ceil_div(num_active_cus, mall_n); // Handle wrap-around case if (mall_m > grid_m) { size_t num_wraps = (mall_m / grid_m); @@ -658,10 +651,12 @@ double compute_memory_latency(const problem_t& problem, // Clamp tile dimensions mall_m = std::max(std::min(grid_m, mall_m), static_cast(1)); mall_n = std::max(std::min(grid_n, mall_n), static_cast(1)); + double concurrent_batches = std::min(static_cast(problem.batch), + std::max(static_cast(num_active_cus) / (splitting_factor * mall_m * mall_n), 1.)); // This is the minimum unique bytes needed from HBM to feed the concurrent workgroups. - double min_load = static_cast((mall_m * config.mt.mk() * static_cast(a_bytes)) + - (mall_n * config.mt.nk() * static_cast(b_bytes))) * - batch; // Apply batching to the minimum load itself. + double min_load = static_cast((mall_m * Ld_A_value * static_cast(a_bytes)) + + (mall_n * Ld_B_value * static_cast(b_bytes))) * + concurrent_batches; // Apply batching to the minimum load itself. // The actual loads cannot be less than this physical minimum. Ld_MEM = std::max(Ld_MEM, min_load); Ld_mem2 = std::max(Ld_mem2, min_load); @@ -890,7 +885,7 @@ double compute_total_latency(const problem_t& problem, } // 1-1) To compute the latency, use default WGM. And WGM can't be greater than one - int defaultWGM = static_cast(ceil(std::sqrt(hardware.N_CU / hardware.NUM_XCD))); + int defaultWGM = batch > 1 ? 1 : static_cast(ceil(std::sqrt(hardware.N_CU / hardware.NUM_XCD))); auto config_with_default_wgm = config; config_with_default_wgm.workgroup_mapping = std::max(defaultWGM, 1); @@ -916,6 +911,21 @@ double compute_total_latency(const problem_t& problem, total_latency = total_latency * 10; } + // TODO is Origami missing something for the "Large N, very small M and K" case? + // In BBS NN gfx950, MT192x336x32 performs poorly on "Large N, very small M and K". + if(MT_M == 192 && MT_N == 336 && MT_K == 32 && !a_trans && !b_trans && problem.a_dtype == data_type_t::BFloat16) { + total_latency = total_latency * 10; + } + // In HHS TN gfx950, MT128x512x32 performs poorly on "Large N, very small M and K". + if(MT_M == 128 && MT_N == 512 && MT_K == 32 && a_trans && !b_trans && problem.a_dtype == data_type_t::Half) { + total_latency = total_latency * 10; + } + // In HHS NT gfx950, DP kernel with MT16x16x64 was added as a fallback kernel (VWA1 and VWB1), + // but it performs very poorly, especially on "Large K, very small M and N". + if(MT_M == 16 && MT_N == 16 && MT_K == 64 && !a_trans && b_trans && problem.a_dtype == data_type_t::Half) { + total_latency = total_latency * 10; + } + bool tf32_emu = ((problem.mi_dtype == data_type_t::XFloat32) && (hardware.arch == hardware_t::architecture_t::gfx950)); From 704af22d3ab53b76409a76838dfcdd437f6c631e Mon Sep 17 00:00:00 2001 From: Pieter Ghysels Date: Fri, 12 Dec 2025 11:12:15 -0600 Subject: [PATCH 2/7] Extra changes related to batch and splitting_factor --- shared/origami/src/origami/gemm.cpp | 60 +++++++++++++++++++++-------- 1 file changed, 44 insertions(+), 16 deletions(-) diff --git a/shared/origami/src/origami/gemm.cpp b/shared/origami/src/origami/gemm.cpp index 1b952fdc804..9852b326038 100644 --- a/shared/origami/src/origami/gemm.cpp +++ b/shared/origami/src/origami/gemm.cpp @@ -363,7 +363,7 @@ double estimate_l2_hit(const problem_t& problem, // Use size_t for dimensions and counts to ensure type safety. const size_t workgroups_m = math::safe_ceil_div(problem.size.m, config.mt.m); const size_t workgroups_n = math::safe_ceil_div(problem.size.n, config.mt.n); - const size_t total_workgroups = workgroups_m * workgroups_n * problem.batch; + const size_t total_workgroups = workgroups_m * workgroups_n; // Concurrently executing workgroups are limited by the number of CUs.a const size_t concurrent_workgroups = std::min(total_workgroups, hardware.N_CU); @@ -372,7 +372,7 @@ double estimate_l2_hit(const problem_t& problem, // Number of CUs that might share the same K-tiles, adjusted for K-splitting. // This affects contention on the L2 cache partitions (XCDs). - const size_t effective_cus = math::safe_ceil_div(concurrent_workgroups, splitting_factor); + const size_t effective_cus = math::safe_ceil_div(concurrent_workgroups, splitting_factor * problem.batch); const size_t cu_per_xcd = std::max(math::safe_ceil_div(effective_cus, hardware.NUM_XCD), static_cast(1)); @@ -391,8 +391,8 @@ double estimate_l2_hit(const problem_t& problem, l2_tile_m = std::max(std::min(workgroups_m, l2_tile_m), static_cast(1)); l2_tile_n = std::max(std::min(workgroups_n, l2_tile_n), static_cast(1)); - double batches_per_xcd = std::min(static_cast(problem.batch), - std::max(static_cast(hardware.CU_per_L2) / (cu_per_xcd * splitting_factor), 1.)); + double tiles_per_xcd = std::min(static_cast(problem.batch * splitting_factor), + std::max(static_cast(hardware.CU_per_L2) / cu_per_xcd, 1.)); // Calculate memory footprint in bytes. const auto a_bytes = data_type_to_bytes(problem.a_dtype); @@ -401,6 +401,7 @@ double estimate_l2_hit(const problem_t& problem, auto a_footprint = tile_m * config.mt.mk() * a_bytes; auto b_footprint = tile_n * config.mt.nk() * b_bytes; return (a_footprint + b_footprint) * batches_per_xcd; + return (a_footprint + b_footprint) * tiles_per_xcd; }; // Symmetrically shrink the L2 tile until it fits in the L2 cache capacity. @@ -452,7 +453,7 @@ double estimate_mall_hit(const problem_t& problem, // --- Initial Tile Sizing based on Concurrency --- // Use ceiling division for a more accurate initial guess. size_t mall_tile_n = std::min(static_cast(config.workgroup_mapping), workgroups_n); - size_t mall_tile_m = math::safe_ceil_div(num_active_cus, mall_tile_n); + size_t mall_tile_m = math::safe_ceil_div(std::min(num_active_cus, workgroups_m * workgroups_n), mall_tile_n); // Handle wrap-around case if the tile is taller than the grid. if (mall_tile_m > workgroups_m) { @@ -475,6 +476,17 @@ double estimate_mall_hit(const problem_t& problem, return a_footprint + b_footprint; }; + while (mall_tile_m * mall_tile_n > num_active_cus) { + if (mall_tile_m > 1 && mall_tile_m >= mall_tile_n) { + mall_tile_m--; + } else if (mall_tile_n > 1) { + mall_tile_n--; + } else { + // Cannot shrink further. + break; + } + } + // --- Calculate Hit Rate based on the final, capacity-aware tile size --- const long long uncached_A_reads = static_cast(mall_tile_m) * config.mt.mk(); const long long uncached_B_reads = static_cast(mall_tile_n) * config.mt.nk(); @@ -520,8 +532,8 @@ double compute_l2_hit_rate_global(const problem_t& problem, const double a_working_set = static_cast(grid_m * config.mt.mk()) * a_bytes; const double b_working_set = static_cast(grid_n * config.mt.nk()) * b_bytes; - const double concurrent_batches = std::min(problem.batch, std::max(hardware.N_CU / (splitting_factor * grid_m * grid_n), static_cast(1))); - const double total_working_set_bytes = (a_working_set + b_working_set) * concurrent_batches; + const double concurrent_tiles = std::min(problem.batch * splitting_factor, std::max(hardware.N_CU / (grid_m * grid_n), static_cast(1))); + const double total_working_set_bytes = (a_working_set + b_working_set) * concurrent_tiles; // 3. CRUCIAL: Check if the working set fits in the L2 cache. // If it doesn't, the global reuse pattern is broken by capacity misses, @@ -641,7 +653,8 @@ double compute_memory_latency(const problem_t& problem, size_t grid_m = math::safe_ceil_div(problem.size.m, MT_M); size_t grid_n = math::safe_ceil_div(problem.size.n, MT_N); size_t mall_n = std::min(static_cast(config.workgroup_mapping), grid_n); - size_t mall_m = math::safe_ceil_div(num_active_cus, mall_n); + // size_t mall_m = math::safe_ceil_div(num_active_cus, mall_n); + size_t mall_m = math::safe_ceil_div(std::min(num_active_cus, grid_m * grid_n), mall_n); // Handle wrap-around case if (mall_m > grid_m) { size_t num_wraps = (mall_m / grid_m); @@ -651,12 +664,22 @@ double compute_memory_latency(const problem_t& problem, // Clamp tile dimensions mall_m = std::max(std::min(grid_m, mall_m), static_cast(1)); mall_n = std::max(std::min(grid_n, mall_n), static_cast(1)); - double concurrent_batches = std::min(static_cast(problem.batch), - std::max(static_cast(num_active_cus) / (splitting_factor * mall_m * mall_n), 1.)); + while (mall_m * mall_n > num_active_cus) { + if (mall_m > 1 && mall_m >= mall_n) { + mall_m--; + } else if (mall_n > 1) { + mall_n--; + } else { + // Cannot shrink further. + break; + } + } + double concurrent_tiles = std::min(static_cast(problem.batch /** splitting_factor*/), + std::max(static_cast(num_active_cus) / (mall_m * mall_n), 1.)); // This is the minimum unique bytes needed from HBM to feed the concurrent workgroups. double min_load = static_cast((mall_m * Ld_A_value * static_cast(a_bytes)) + - (mall_n * Ld_B_value * static_cast(b_bytes))) * - concurrent_batches; // Apply batching to the minimum load itself. + (mall_n * Ld_B_value * static_cast(b_bytes))) * + concurrent_tiles; // Apply batching to the minimum load itself. // The actual loads cannot be less than this physical minimum. Ld_MEM = std::max(Ld_MEM, min_load); Ld_mem2 = std::max(Ld_mem2, min_load); @@ -724,7 +747,7 @@ double compute_tile_latency(const problem_t& problem, double mem_bw_occ_limited = hardware.mem3_perf_ratio * mem_bw_occ; size_t MT_M_rounded_128bytes = round_elements_to_128B(MT_M, datatype_to_bits(problem.a_dtype)); - double L_epilogue = (static_cast(num_active_cus / splitting_factor) * + double L_epilogue = (static_cast(num_active_cus /*/ splitting_factor*/) * MT_M_rounded_128bytes * MT_N * d_bytes) / mem_bw_occ_limited; // One compute iteration happens in the prologue @@ -743,7 +766,11 @@ double compute_tile_latency(const problem_t& problem, // 4') K-split reductions are globally coherent, we need to write and read split-1 MT_M*MT_N // tiles to coherent memory if (splitting_factor > 1) { - size_t n_partials = splitting_factor - 1; + double mem_bw_occ = compute_mem_bw_from_occupancy(hardware, num_active_cus / splitting_factor); + double mem_bw_occ_limited = hardware.mem3_perf_ratio * mem_bw_occ; + + double n_partials = std::ceil(std::log2(static_cast(splitting_factor))); //splitting_factor - 1; + // size_t n_partials = splitting_factor - 1; // Only the reduction CU reads from all splits. double partial_read_bytes = @@ -758,7 +785,8 @@ double compute_tile_latency(const problem_t& problem, // 64 Threads active in a SIMD. Exposed to at least latency of reducing splitting_factor // tiles. double partial_adds = - (static_cast(config.mt.mn()) * static_cast(splitting_factor)) / (64); + // (static_cast(config.mt.mn()) * static_cast(splitting_factor)) / (64); + (static_cast(config.mt.mn()) * n_partials) / (64); double L_reduce = partial_readwrite_bytes / (mem_bw_occ_limited); L_epilogue += L_reduce + partial_adds + 10000; @@ -885,7 +913,7 @@ double compute_total_latency(const problem_t& problem, } // 1-1) To compute the latency, use default WGM. And WGM can't be greater than one - int defaultWGM = batch > 1 ? 1 : static_cast(ceil(std::sqrt(hardware.N_CU / hardware.NUM_XCD))); + int defaultWGM = batch > 1 ? 1 : static_cast(ceil(std::sqrt(max_cus / hardware.NUM_XCD))); auto config_with_default_wgm = config; config_with_default_wgm.workgroup_mapping = std::max(defaultWGM, 1); From e3590a80deb0ca7658fc3e02082ab7be9f1fa491 Mon Sep 17 00:00:00 2001 From: Pieter Ghysels Date: Tue, 16 Dec 2025 11:28:37 -0600 Subject: [PATCH 3/7] Apply splitting_factor in min_load, do not set hit rate to .5 if estimated as 0. --- shared/origami/src/origami/gemm.cpp | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/shared/origami/src/origami/gemm.cpp b/shared/origami/src/origami/gemm.cpp index 9852b326038..6e0c00d3009 100644 --- a/shared/origami/src/origami/gemm.cpp +++ b/shared/origami/src/origami/gemm.cpp @@ -604,7 +604,7 @@ double compute_memory_latency(const problem_t& problem, H_mem1 = std::min(H_mem1, H_mem1_global); - if (H_mem1 == 0) { H_mem1 = 0.5; } + // if (H_mem1 == 0) { H_mem1 = 0.5; } // 2) Estimate mall hit-rate double H_mem2 = estimate_mall_hit(problem, hardware, config, num_active_cus, splitting_factor); @@ -674,7 +674,7 @@ double compute_memory_latency(const problem_t& problem, break; } } - double concurrent_tiles = std::min(static_cast(problem.batch /** splitting_factor*/), + double concurrent_tiles = std::min(static_cast(problem.batch * splitting_factor), std::max(static_cast(num_active_cus) / (mall_m * mall_n), 1.)); // This is the minimum unique bytes needed from HBM to feed the concurrent workgroups. double min_load = static_cast((mall_m * Ld_A_value * static_cast(a_bytes)) + @@ -747,7 +747,7 @@ double compute_tile_latency(const problem_t& problem, double mem_bw_occ_limited = hardware.mem3_perf_ratio * mem_bw_occ; size_t MT_M_rounded_128bytes = round_elements_to_128B(MT_M, datatype_to_bits(problem.a_dtype)); - double L_epilogue = (static_cast(num_active_cus /*/ splitting_factor*/) * + double L_epilogue = (static_cast(num_active_cus / splitting_factor) * MT_M_rounded_128bytes * MT_N * d_bytes) / mem_bw_occ_limited; // One compute iteration happens in the prologue @@ -766,11 +766,7 @@ double compute_tile_latency(const problem_t& problem, // 4') K-split reductions are globally coherent, we need to write and read split-1 MT_M*MT_N // tiles to coherent memory if (splitting_factor > 1) { - double mem_bw_occ = compute_mem_bw_from_occupancy(hardware, num_active_cus / splitting_factor); - double mem_bw_occ_limited = hardware.mem3_perf_ratio * mem_bw_occ; - - double n_partials = std::ceil(std::log2(static_cast(splitting_factor))); //splitting_factor - 1; - // size_t n_partials = splitting_factor - 1; + size_t n_partials = splitting_factor - 1; // Only the reduction CU reads from all splits. double partial_read_bytes = @@ -785,8 +781,7 @@ double compute_tile_latency(const problem_t& problem, // 64 Threads active in a SIMD. Exposed to at least latency of reducing splitting_factor // tiles. double partial_adds = - // (static_cast(config.mt.mn()) * static_cast(splitting_factor)) / (64); - (static_cast(config.mt.mn()) * n_partials) / (64); + (static_cast(config.mt.mn()) * static_cast(splitting_factor)) / (64); double L_reduce = partial_readwrite_bytes / (mem_bw_occ_limited); L_epilogue += L_reduce + partial_adds + 10000; From 573b2fed465f4ba087af2a573b155ace7c5ccf56 Mon Sep 17 00:00:00 2001 From: Pieter Ghysels Date: Tue, 16 Dec 2025 15:02:17 -0600 Subject: [PATCH 4/7] Apply heuristics only for gfx950 --- shared/origami/src/origami/gemm.cpp | 28 +++++++++++++++------------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/shared/origami/src/origami/gemm.cpp b/shared/origami/src/origami/gemm.cpp index 6e0c00d3009..866491af16e 100644 --- a/shared/origami/src/origami/gemm.cpp +++ b/shared/origami/src/origami/gemm.cpp @@ -934,19 +934,21 @@ double compute_total_latency(const problem_t& problem, total_latency = total_latency * 10; } - // TODO is Origami missing something for the "Large N, very small M and K" case? - // In BBS NN gfx950, MT192x336x32 performs poorly on "Large N, very small M and K". - if(MT_M == 192 && MT_N == 336 && MT_K == 32 && !a_trans && !b_trans && problem.a_dtype == data_type_t::BFloat16) { - total_latency = total_latency * 10; - } - // In HHS TN gfx950, MT128x512x32 performs poorly on "Large N, very small M and K". - if(MT_M == 128 && MT_N == 512 && MT_K == 32 && a_trans && !b_trans && problem.a_dtype == data_type_t::Half) { - total_latency = total_latency * 10; - } - // In HHS NT gfx950, DP kernel with MT16x16x64 was added as a fallback kernel (VWA1 and VWB1), - // but it performs very poorly, especially on "Large K, very small M and N". - if(MT_M == 16 && MT_N == 16 && MT_K == 64 && !a_trans && b_trans && problem.a_dtype == data_type_t::Half) { - total_latency = total_latency * 10; + if(hardware.arch == hardware_t::architecture_t::gfx950) { + // TODO is Origami missing something for the "Large N, very small M and K" case? + // In BBS NN gfx950, MT192x336x32 performs poorly on "Large N, very small M and K". + if(MT_M == 192 && MT_N == 336 && MT_K == 32 && !a_trans && !b_trans && problem.a_dtype == data_type_t::BFloat16) { + total_latency = total_latency * 10; + } + // In HHS TN gfx950, MT128x512x32 performs poorly on "Large N, very small M and K". + if(MT_M == 128 && MT_N == 512 && MT_K == 32 && a_trans && !b_trans && problem.a_dtype == data_type_t::Half) { + total_latency = total_latency * 10; + } + // In HHS NT gfx950, DP kernel with MT16x16x64 was added as a fallback kernel (VWA1 and VWB1), + // but it performs very poorly, especially on "Large K, very small M and N". + if(MT_M == 16 && MT_N == 16 && MT_K == 64 && !a_trans && b_trans && problem.a_dtype == data_type_t::Half) { + total_latency = total_latency * 10; + } } bool tf32_emu = ((problem.mi_dtype == data_type_t::XFloat32) && From 4be4cbef2095186844656f233f0e3ab995354b99 Mon Sep 17 00:00:00 2001 From: Pieter Ghysels Date: Thu, 18 Dec 2025 15:44:40 -0600 Subject: [PATCH 5/7] Minor updates --- shared/origami/src/origami/gemm.cpp | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/shared/origami/src/origami/gemm.cpp b/shared/origami/src/origami/gemm.cpp index 866491af16e..7d3f76fdc0a 100644 --- a/shared/origami/src/origami/gemm.cpp +++ b/shared/origami/src/origami/gemm.cpp @@ -377,7 +377,7 @@ double estimate_l2_hit(const problem_t& problem, std::max(math::safe_ceil_div(effective_cus, hardware.NUM_XCD), static_cast(1)); // Initial guess for the L2 tile dimensions (a tile of workgroups). - size_t l2_tile_n = std::min(cu_per_xcd, std::min(static_cast(config.workgroup_mapping), workgroups_n)); + size_t l2_tile_n = std::min(static_cast(config.workgroup_mapping), workgroups_n); size_t l2_tile_m = math::safe_ceil_div(cu_per_xcd, l2_tile_n); // Handle wrap-around case: if the tile is taller than the grid, wrap it to be wider. @@ -400,7 +400,6 @@ double estimate_l2_hit(const problem_t& problem, auto calculate_footprint = [&](auto tile_m, auto tile_n) { auto a_footprint = tile_m * config.mt.mk() * a_bytes; auto b_footprint = tile_n * config.mt.nk() * b_bytes; - return (a_footprint + b_footprint) * batches_per_xcd; return (a_footprint + b_footprint) * tiles_per_xcd; }; @@ -513,6 +512,7 @@ double compute_l2_hit_rate_global(const problem_t& problem, const hardware_t& hardware, const config_t& config, size_t l2_capacity_bytes, + size_t num_active_cus, size_t splitting_factor) { // --- Hardware Parameters (as requested, defined locally) --- // You would normally get l2_capacity_bytes from your hardware_t struct. @@ -532,7 +532,7 @@ double compute_l2_hit_rate_global(const problem_t& problem, const double a_working_set = static_cast(grid_m * config.mt.mk()) * a_bytes; const double b_working_set = static_cast(grid_n * config.mt.nk()) * b_bytes; - const double concurrent_tiles = std::min(problem.batch * splitting_factor, std::max(hardware.N_CU / (grid_m * grid_n), static_cast(1))); + const double concurrent_tiles = std::min(problem.batch * splitting_factor, std::max(num_active_cus / (grid_m * grid_n), static_cast(1))); const double total_working_set_bytes = (a_working_set + b_working_set) * concurrent_tiles; // 3. CRUCIAL: Check if the working set fits in the L2 cache. @@ -600,11 +600,11 @@ double compute_memory_latency(const problem_t& problem, // (Assumes capacity is given in KiB, convert to bytes) // TODO hardware.L2_capacity is already in bytes? double H_mem1_global = - compute_l2_hit_rate_global(problem, hardware, config, hardware.L2_capacity * 1024, splitting_factor); + compute_l2_hit_rate_global(problem, hardware, config, hardware.L2_capacity * 1024, num_active_cus, splitting_factor); H_mem1 = std::min(H_mem1, H_mem1_global); - // if (H_mem1 == 0) { H_mem1 = 0.5; } + if (H_mem1 == 0) { H_mem1 = 0.1; } // 2) Estimate mall hit-rate double H_mem2 = estimate_mall_hit(problem, hardware, config, num_active_cus, splitting_factor); @@ -653,7 +653,6 @@ double compute_memory_latency(const problem_t& problem, size_t grid_m = math::safe_ceil_div(problem.size.m, MT_M); size_t grid_n = math::safe_ceil_div(problem.size.n, MT_N); size_t mall_n = std::min(static_cast(config.workgroup_mapping), grid_n); - // size_t mall_m = math::safe_ceil_div(num_active_cus, mall_n); size_t mall_m = math::safe_ceil_div(std::min(num_active_cus, grid_m * grid_n), mall_n); // Handle wrap-around case if (mall_m > grid_m) { @@ -675,10 +674,10 @@ double compute_memory_latency(const problem_t& problem, } } double concurrent_tiles = std::min(static_cast(problem.batch * splitting_factor), - std::max(static_cast(num_active_cus) / (mall_m * mall_n), 1.)); + std::max(static_cast(num_active_cus) / (grid_m * grid_n), 1.)); // This is the minimum unique bytes needed from HBM to feed the concurrent workgroups. - double min_load = static_cast((mall_m * Ld_A_value * static_cast(a_bytes)) + - (mall_n * Ld_B_value * static_cast(b_bytes))) * + double min_load = static_cast((mall_m * Ld_A_value * a_bytes) + + (mall_n * Ld_B_value * b_bytes)) * concurrent_tiles; // Apply batching to the minimum load itself. // The actual loads cannot be less than this physical minimum. Ld_MEM = std::max(Ld_MEM, min_load); @@ -909,6 +908,11 @@ double compute_total_latency(const problem_t& problem, // 1-1) To compute the latency, use default WGM. And WGM can't be greater than one int defaultWGM = batch > 1 ? 1 : static_cast(ceil(std::sqrt(max_cus / hardware.NUM_XCD))); + const size_t grid_m = math::safe_ceil_div(M, MT_N); + const size_t grid_n = math::safe_ceil_div(N, MT_N); + if(grid_m * grid_n > hardware.N_CU && M > 10 * N && grid_n <= 8) { + defaultWGM = grid_n; + } auto config_with_default_wgm = config; config_with_default_wgm.workgroup_mapping = std::max(defaultWGM, 1); From cf1952ee0cc8c239686c515934f06f15aa12d747 Mon Sep 17 00:00:00 2001 From: Pieter Ghysels Date: Fri, 19 Dec 2025 19:16:48 -0600 Subject: [PATCH 6/7] Simplify changes --- shared/origami/src/origami/gemm.cpp | 64 ++++++++--------------------- 1 file changed, 16 insertions(+), 48 deletions(-) diff --git a/shared/origami/src/origami/gemm.cpp b/shared/origami/src/origami/gemm.cpp index 7d3f76fdc0a..ebff0eb84c8 100644 --- a/shared/origami/src/origami/gemm.cpp +++ b/shared/origami/src/origami/gemm.cpp @@ -391,22 +391,18 @@ double estimate_l2_hit(const problem_t& problem, l2_tile_m = std::max(std::min(workgroups_m, l2_tile_m), static_cast(1)); l2_tile_n = std::max(std::min(workgroups_n, l2_tile_n), static_cast(1)); - double tiles_per_xcd = std::min(static_cast(problem.batch * splitting_factor), - std::max(static_cast(hardware.CU_per_L2) / cu_per_xcd, 1.)); - // Calculate memory footprint in bytes. const auto a_bytes = data_type_to_bytes(problem.a_dtype); const auto b_bytes = data_type_to_bytes(problem.b_dtype); auto calculate_footprint = [&](auto tile_m, auto tile_n) { auto a_footprint = tile_m * config.mt.mk() * a_bytes; auto b_footprint = tile_n * config.mt.nk() * b_bytes; - return (a_footprint + b_footprint) * tiles_per_xcd; + return a_footprint + b_footprint; }; // Symmetrically shrink the L2 tile until it fits in the L2 cache capacity. // This is more robust than shrinking only one dimension. - while (calculate_footprint(l2_tile_m, l2_tile_n) > hardware.L2_capacity - || l2_tile_m * l2_tile_n > cu_per_xcd) { + while (calculate_footprint(l2_tile_m, l2_tile_n) > hardware.L2_capacity) { if (l2_tile_m > 1 && l2_tile_m >= l2_tile_n) { l2_tile_m--; } else if (l2_tile_n > 1) { @@ -451,8 +447,9 @@ double estimate_mall_hit(const problem_t& problem, // --- Initial Tile Sizing based on Concurrency --- // Use ceiling division for a more accurate initial guess. + size_t mall_tile_m = + math::safe_ceil_div(num_active_cus, static_cast(config.workgroup_mapping)); size_t mall_tile_n = std::min(static_cast(config.workgroup_mapping), workgroups_n); - size_t mall_tile_m = math::safe_ceil_div(std::min(num_active_cus, workgroups_m * workgroups_n), mall_tile_n); // Handle wrap-around case if the tile is taller than the grid. if (mall_tile_m > workgroups_m) { @@ -475,17 +472,6 @@ double estimate_mall_hit(const problem_t& problem, return a_footprint + b_footprint; }; - while (mall_tile_m * mall_tile_n > num_active_cus) { - if (mall_tile_m > 1 && mall_tile_m >= mall_tile_n) { - mall_tile_m--; - } else if (mall_tile_n > 1) { - mall_tile_n--; - } else { - // Cannot shrink further. - break; - } - } - // --- Calculate Hit Rate based on the final, capacity-aware tile size --- const long long uncached_A_reads = static_cast(mall_tile_m) * config.mt.mk(); const long long uncached_B_reads = static_cast(mall_tile_n) * config.mt.nk(); @@ -511,9 +497,7 @@ double estimate_mall_hit(const problem_t& problem, double compute_l2_hit_rate_global(const problem_t& problem, const hardware_t& hardware, const config_t& config, - size_t l2_capacity_bytes, - size_t num_active_cus, - size_t splitting_factor) { + size_t l2_capacity_bytes) { // --- Hardware Parameters (as requested, defined locally) --- // You would normally get l2_capacity_bytes from your hardware_t struct. if (l2_capacity_bytes == 0) throw std::runtime_error("L2 Capacity is zero"); @@ -532,8 +516,7 @@ double compute_l2_hit_rate_global(const problem_t& problem, const double a_working_set = static_cast(grid_m * config.mt.mk()) * a_bytes; const double b_working_set = static_cast(grid_n * config.mt.nk()) * b_bytes; - const double concurrent_tiles = std::min(problem.batch * splitting_factor, std::max(num_active_cus / (grid_m * grid_n), static_cast(1))); - const double total_working_set_bytes = (a_working_set + b_working_set) * concurrent_tiles; + const double total_working_set_bytes = a_working_set + b_working_set; // 3. CRUCIAL: Check if the working set fits in the L2 cache. // If it doesn't, the global reuse pattern is broken by capacity misses, @@ -598,13 +581,12 @@ double compute_memory_latency(const problem_t& problem, // Global cap on L2 hit-rate (prevents impossible cache residency claims) // (Assumes capacity is given in KiB, convert to bytes) - // TODO hardware.L2_capacity is already in bytes? double H_mem1_global = - compute_l2_hit_rate_global(problem, hardware, config, hardware.L2_capacity * 1024, num_active_cus, splitting_factor); + compute_l2_hit_rate_global(problem, hardware, config, hardware.L2_capacity * 1024); H_mem1 = std::min(H_mem1, H_mem1_global); - if (H_mem1 == 0) { H_mem1 = 0.1; } + if (H_mem1 == 0) { H_mem1 = 0.5; } // 2) Estimate mall hit-rate double H_mem2 = estimate_mall_hit(problem, hardware, config, num_active_cus, splitting_factor); @@ -652,8 +634,9 @@ double compute_memory_latency(const problem_t& problem, // Calculate the tile of workgroups that can run concurrently (logic from estimate_mall_hit). size_t grid_m = math::safe_ceil_div(problem.size.m, MT_M); size_t grid_n = math::safe_ceil_div(problem.size.n, MT_N); + size_t mall_m = + math::safe_ceil_div(num_active_cus, static_cast(config.workgroup_mapping)); size_t mall_n = std::min(static_cast(config.workgroup_mapping), grid_n); - size_t mall_m = math::safe_ceil_div(std::min(num_active_cus, grid_m * grid_n), mall_n); // Handle wrap-around case if (mall_m > grid_m) { size_t num_wraps = (mall_m / grid_m); @@ -663,22 +646,12 @@ double compute_memory_latency(const problem_t& problem, // Clamp tile dimensions mall_m = std::max(std::min(grid_m, mall_m), static_cast(1)); mall_n = std::max(std::min(grid_n, mall_n), static_cast(1)); - while (mall_m * mall_n > num_active_cus) { - if (mall_m > 1 && mall_m >= mall_n) { - mall_m--; - } else if (mall_n > 1) { - mall_n--; - } else { - // Cannot shrink further. - break; - } - } - double concurrent_tiles = std::min(static_cast(problem.batch * splitting_factor), - std::max(static_cast(num_active_cus) / (grid_m * grid_n), 1.)); // This is the minimum unique bytes needed from HBM to feed the concurrent workgroups. - double min_load = static_cast((mall_m * Ld_A_value * a_bytes) + - (mall_n * Ld_B_value * b_bytes)) * - concurrent_tiles; // Apply batching to the minimum load itself. + double concurrent_batches = std::min(static_cast(problem.batch), + std::max(static_cast(num_active_cus) / (grid_m * grid_n), 1.)); + double min_load = static_cast((mall_m * config.mt.mk() * a_bytes) + + (mall_n * config.mt.nk() * b_bytes)) * + concurrent_batches; // Apply batching to the minimum load itself. // The actual loads cannot be less than this physical minimum. Ld_MEM = std::max(Ld_MEM, min_load); Ld_mem2 = std::max(Ld_mem2, min_load); @@ -907,12 +880,7 @@ double compute_total_latency(const problem_t& problem, } // 1-1) To compute the latency, use default WGM. And WGM can't be greater than one - int defaultWGM = batch > 1 ? 1 : static_cast(ceil(std::sqrt(max_cus / hardware.NUM_XCD))); - const size_t grid_m = math::safe_ceil_div(M, MT_N); - const size_t grid_n = math::safe_ceil_div(N, MT_N); - if(grid_m * grid_n > hardware.N_CU && M > 10 * N && grid_n <= 8) { - defaultWGM = grid_n; - } + int defaultWGM = batch > 1 ? 1 : static_cast(ceil(std::sqrt(hardware.N_CU / hardware.NUM_XCD))); auto config_with_default_wgm = config; config_with_default_wgm.workgroup_mapping = std::max(defaultWGM, 1); From fad7069ccd5f935d5a70341df16ff650d5844864 Mon Sep 17 00:00:00 2001 From: Pieter Ghysels Date: Fri, 19 Dec 2025 19:41:12 -0600 Subject: [PATCH 7/7] Remove heuristics, rounding of MT --- shared/origami/src/origami/gemm.cpp | 37 +++++++++++++---------------- 1 file changed, 16 insertions(+), 21 deletions(-) diff --git a/shared/origami/src/origami/gemm.cpp b/shared/origami/src/origami/gemm.cpp index ebff0eb84c8..a89fc8f1e26 100644 --- a/shared/origami/src/origami/gemm.cpp +++ b/shared/origami/src/origami/gemm.cpp @@ -592,10 +592,22 @@ double compute_memory_latency(const problem_t& problem, double H_mem2 = estimate_mall_hit(problem, hardware, config, num_active_cus, splitting_factor); // 3) Total loads are loads from A and loads from B - size_t Ld_A_value = a_trans ? - MT_M * round_elements_to_128B(MT_K, a_bits) : round_elements_to_128B(MT_M, a_bits) * MT_K; - size_t Ld_B_value = b_trans ? - round_elements_to_128B(MT_N, b_bits) * MT_K : MT_N * round_elements_to_128B(MT_K, b_bits); + size_t MT_M_rounded_128bytes = round_elements_to_128B(MT_M, a_bits); + size_t MT_N_rounded_128bytes = round_elements_to_128B(MT_N, a_bits); + size_t MT_K_rounded_128bytes = round_elements_to_128B(MT_K, a_bits); + + if (!a_trans && !b_trans) { + MT_N_rounded_128bytes = MT_N; + MT_K_rounded_128bytes = MT_K; + } else if (a_trans && !b_trans) { + MT_M_rounded_128bytes = MT_M; + MT_N_rounded_128bytes = MT_N; + } else if (!a_trans && b_trans) { + MT_K_rounded_128bytes = MT_K; + } + + size_t Ld_A_value = MT_M_rounded_128bytes * MT_K_rounded_128bytes; + size_t Ld_B_value = MT_N_rounded_128bytes * MT_K_rounded_128bytes; auto Ld_CU_bytes = (Ld_A_value * a_bytes) // A Bytes + (Ld_B_value * b_bytes); // B Bytes @@ -906,23 +918,6 @@ double compute_total_latency(const problem_t& problem, total_latency = total_latency * 10; } - if(hardware.arch == hardware_t::architecture_t::gfx950) { - // TODO is Origami missing something for the "Large N, very small M and K" case? - // In BBS NN gfx950, MT192x336x32 performs poorly on "Large N, very small M and K". - if(MT_M == 192 && MT_N == 336 && MT_K == 32 && !a_trans && !b_trans && problem.a_dtype == data_type_t::BFloat16) { - total_latency = total_latency * 10; - } - // In HHS TN gfx950, MT128x512x32 performs poorly on "Large N, very small M and K". - if(MT_M == 128 && MT_N == 512 && MT_K == 32 && a_trans && !b_trans && problem.a_dtype == data_type_t::Half) { - total_latency = total_latency * 10; - } - // In HHS NT gfx950, DP kernel with MT16x16x64 was added as a fallback kernel (VWA1 and VWB1), - // but it performs very poorly, especially on "Large K, very small M and N". - if(MT_M == 16 && MT_N == 16 && MT_K == 64 && !a_trans && b_trans && problem.a_dtype == data_type_t::Half) { - total_latency = total_latency * 10; - } - } - bool tf32_emu = ((problem.mi_dtype == data_type_t::XFloat32) && (hardware.arch == hardware_t::architecture_t::gfx950));