diff --git a/shared/origami/src/origami/gemm.cpp b/shared/origami/src/origami/gemm.cpp index 2404ced4254..a89fc8f1e26 100644 --- a/shared/origami/src/origami/gemm.cpp +++ b/shared/origami/src/origami/gemm.cpp @@ -372,7 +372,7 @@ double estimate_l2_hit(const problem_t& problem, // Number of CUs that might share the same K-tiles, adjusted for K-splitting. // This affects contention on the L2 cache partitions (XCDs). - const size_t effective_cus = math::safe_ceil_div(concurrent_workgroups, splitting_factor); + const size_t effective_cus = math::safe_ceil_div(concurrent_workgroups, splitting_factor * problem.batch); const size_t cu_per_xcd = std::max(math::safe_ceil_div(effective_cus, hardware.NUM_XCD), static_cast(1)); @@ -659,9 +659,11 @@ double compute_memory_latency(const problem_t& problem, mall_m = std::max(std::min(grid_m, mall_m), static_cast(1)); mall_n = std::max(std::min(grid_n, mall_n), static_cast(1)); // This is the minimum unique bytes needed from HBM to feed the concurrent workgroups. - double min_load = static_cast((mall_m * config.mt.mk() * static_cast(a_bytes)) + - (mall_n * config.mt.nk() * static_cast(b_bytes))) * - batch; // Apply batching to the minimum load itself. + double concurrent_batches = std::min(static_cast(problem.batch), + std::max(static_cast(num_active_cus) / (grid_m * grid_n), 1.)); + double min_load = static_cast((mall_m * config.mt.mk() * a_bytes) + + (mall_n * config.mt.nk() * b_bytes)) * + concurrent_batches; // Apply batching to the minimum load itself. // The actual loads cannot be less than this physical minimum. Ld_MEM = std::max(Ld_MEM, min_load); Ld_mem2 = std::max(Ld_mem2, min_load); @@ -890,7 +892,7 @@ double compute_total_latency(const problem_t& problem, } // 1-1) To compute the latency, use default WGM. And WGM can't be greater than one - int defaultWGM = static_cast(ceil(std::sqrt(hardware.N_CU / hardware.NUM_XCD))); + int defaultWGM = batch > 1 ? 1 : static_cast(ceil(std::sqrt(hardware.N_CU / hardware.NUM_XCD))); auto config_with_default_wgm = config; config_with_default_wgm.workgroup_mapping = std::max(defaultWGM, 1);