Skip to content

Commit 8493e78

Browse files
committed
Benchmark cleanups
Signed-off-by: djns99 <[email protected]>
1 parent 79888c4 commit 8493e78

File tree

1 file changed

+24
-29
lines changed

1 file changed

+24
-29
lines changed

cpp/micro_benchmarks/mixtureOfExpertsBackendBenchmarkLauncher.cu

Lines changed: 24 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -273,35 +273,15 @@ void argGenLoadFile(benchmark::internal::Benchmark* benchmark)
273273
}
274274

275275
// Do this after filtering datatypes as tactics only make sense if we know the data type
276-
bool has_tactic_ids2 = false;
277276
std::vector<int> tactic_ids1{};
278277
std::vector<int> tactic_ids2{};
279-
if (run_config.contains("tactic_id1") || run_config.contains("tactic_id2"))
278+
if (run_config.contains("tactic_id1"))
280279
{
281-
has_tactic_ids2 = true;
282280
parseTacticToVectorID<BenchClass>(run_config["tactic_id1"], tactic_ids1, MoeGemmId::GEMM_1);
283-
parseTacticToVectorID<BenchClass>(run_config["tactic_id2"], tactic_ids2, MoeGemmId::GEMM_2);
284281
}
285-
286-
if (tactic_ids1.empty() || tactic_ids2.empty())
282+
if (run_config.contains("tactic_id2"))
287283
{
288-
std::cerr << "Warning: Skipping benchmark, no valid tactic found" << std::endl;
289-
static bool printed = false;
290-
if (!printed)
291-
{
292-
printed = true;
293-
std::cerr << __PRETTY_FUNCTION__ << ": Valid Tactics are:\n";
294-
for (auto gemm_id : {MoeGemmId::GEMM_1, MoeGemmId::GEMM_2})
295-
{
296-
std::cerr << "GEMM " << (int) gemm_id << ":\n";
297-
auto confs = listAllTactics<BenchClass>(gemm_id);
298-
for (auto c : confs)
299-
std::cerr << c.toString();
300-
std::cerr << std::endl;
301-
}
302-
}
303-
304-
continue;
284+
parseTacticToVectorID<BenchClass>(run_config["tactic_id2"], tactic_ids2, MoeGemmId::GEMM_2);
305285
}
306286

307287
auto get_or = [&](auto name, auto def)
@@ -337,8 +317,6 @@ void argGenLoadFile(benchmark::internal::Benchmark* benchmark)
337317
}
338318
else if (gemm_to_profile == (int) GemmToProfile::GEMM_2)
339319
{
340-
if (!has_tactic_ids2)
341-
tactic_ids2 = std::move(tactic_ids1);
342320
tactic_ids1 = {-1};
343321
}
344322
}
@@ -353,14 +331,31 @@ void argGenLoadFile(benchmark::internal::Benchmark* benchmark)
353331
return val;
354332
};
355333

334+
if (tactic_ids1.empty() || tactic_ids2.empty())
335+
{
336+
std::cerr << "Warning: Skipping benchmark, no valid tactic found" << std::endl;
337+
static bool printed = false;
338+
if (!printed)
339+
{
340+
printed = true;
341+
std::cerr << __PRETTY_FUNCTION__ << ": Valid Tactics are:\n";
342+
for (auto gemm_id : {MoeGemmId::GEMM_1, MoeGemmId::GEMM_2})
343+
{
344+
std::cerr << "GEMM " << (int) gemm_id << ":\n";
345+
auto confs = listAllTactics<BenchClass>(gemm_id);
346+
for (auto c : confs)
347+
std::cerr << c.toString();
348+
std::cerr << std::endl;
349+
}
350+
}
351+
352+
continue;
353+
}
354+
356355
for (auto t1 : tactic_ids1)
357356
{
358-
// tactic_ids2 will have one dummy value if has_tactic_ids2 = false
359357
for (auto t2 : tactic_ids2)
360358
{
361-
if (!has_tactic_ids2)
362-
t2 = t1;
363-
364359
benchmark->Args({num_experts, //
365360
get_range("k"), //
366361
get_range("hidden_size"), //

0 commit comments

Comments
 (0)