Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions shared/origami/src/origami/gemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,7 @@ double estimate_l2_hit(const problem_t& problem,

// Number of CUs that might share the same K-tiles, adjusted for K-splitting.
// This affects contention on the L2 cache partitions (XCDs).
const size_t effective_cus = math::safe_ceil_div(concurrent_workgroups, splitting_factor);
const size_t effective_cus = math::safe_ceil_div(concurrent_workgroups, splitting_factor * problem.batch);
const size_t cu_per_xcd =
std::max(math::safe_ceil_div(effective_cus, hardware.NUM_XCD), static_cast<size_t>(1));

Expand Down Expand Up @@ -659,9 +659,11 @@ double compute_memory_latency(const problem_t& problem,
mall_m = std::max(std::min(grid_m, mall_m), static_cast<size_t>(1));
mall_n = std::max(std::min(grid_n, mall_n), static_cast<size_t>(1));
// This is the minimum unique bytes needed from HBM to feed the concurrent workgroups.
double min_load = static_cast<double>((mall_m * config.mt.mk() * static_cast<size_t>(a_bytes)) +
(mall_n * config.mt.nk() * static_cast<size_t>(b_bytes))) *
batch; // Apply batching to the minimum load itself.
double concurrent_batches = std::min(static_cast<double>(problem.batch),
std::max(static_cast<double>(num_active_cus) / (grid_m * grid_n), 1.));
double min_load = static_cast<double>((mall_m * config.mt.mk() * a_bytes) +
(mall_n * config.mt.nk() * b_bytes)) *
concurrent_batches; // Apply batching to the minimum load itself.
// The actual loads cannot be less than this physical minimum.
Ld_MEM = std::max(Ld_MEM, min_load);
Ld_mem2 = std::max(Ld_mem2, min_load);
Expand Down Expand Up @@ -890,7 +892,7 @@ double compute_total_latency(const problem_t& problem,
}

// 1-1) To compute the latency, use default WGM. And WGM can't be greater than one
int defaultWGM = static_cast<int>(ceil(std::sqrt(hardware.N_CU / hardware.NUM_XCD)));
int defaultWGM = batch > 1 ? 1 : static_cast<int>(ceil(std::sqrt(hardware.N_CU / hardware.NUM_XCD)));
auto config_with_default_wgm = config;
config_with_default_wgm.workgroup_mapping = std::max(defaultWGM, 1);

Expand Down
Loading