1
+ #include " ../../vowpalwabbit/core/tests/simulator.h"
2
+ #include " benchmarks_common.h"
3
+ #include " vw/cache_parser/parse_example_cache.h"
4
+ #include " vw/config/options_cli.h"
5
+ #include " vw/core/learner.h"
6
+ #include " vw/core/metric_sink.h"
7
+ #include " vw/core/parser.h"
8
+ #include " vw/core/reductions/epsilon_decay.h"
9
+ #include " vw/core/setup_base.h"
10
+ #include " vw/core/vw.h"
11
+ #include " vw/io/io_adapter.h"
12
+ #include " vw/text_parser/parse_example_text.h"
13
+
14
+ #include < benchmark/benchmark.h>
15
+
16
+ template <class ... ExtraArgs>
17
+ static void bench_epsilon_decay (benchmark::State& state, bool use_decay, ExtraArgs&&... extra_args)
18
+ {
19
+ std::array<std::string, sizeof ...(extra_args)> res = {extra_args...};
20
+ std::string model_count = res[0 ];
21
+ std::string bit_size = res[1 ];
22
+ std::string tolerance = res[2 ];
23
+
24
+ using callback_map =
25
+ typename std::map<size_t , std::function<bool (simulator::cb_sim&, VW::workspace&, VW::multi_ex&)>>;
26
+ callback_map test_hooks;
27
+
28
+ for (auto _ : state)
29
+ {
30
+ const size_t num_iterations = 1000 ;
31
+ const size_t seed = 99 ;
32
+ const std::vector<uint64_t > swap_after = {500 };
33
+ if (use_decay)
34
+ {
35
+ simulator::_test_helper_hook (
36
+ std::vector<std::string>{" -l" , " 1e-3" , " --power_t" , " 0" , " -q::" , " --cb_explore_adf" , " --epsilon_decay" ,
37
+ " --model_count" , model_count, " -b" , bit_size, " --tol_x" , tolerance, " --quiet" },
38
+ test_hooks, num_iterations, seed, swap_after);
39
+ }
40
+ else
41
+ {
42
+ simulator::_test_helper_hook (
43
+ std::vector<std::string>{" -l" , " 1e-3" , " --power_t" , " 0" , " -q::" , " --cb_explore_adf" , " --quiet" }, test_hooks,
44
+ num_iterations, seed, swap_after);
45
+ }
46
+ benchmark::ClobberMemory ();
47
+ }
48
+ }
49
+
50
+ BENCHMARK_CAPTURE (bench_epsilon_decay, epsilon_decay_1_model_big_tol, true , " 1" , " 18" , " 1e-2" );
51
+ BENCHMARK_CAPTURE (bench_epsilon_decay, epsilon_decay_2_model_big_tol, true , " 2" , " 19" , " 1e-2" );
52
+ BENCHMARK_CAPTURE (bench_epsilon_decay, epsilon_decay_4_model_big_tol, true , " 4" , " 20" , " 1e-2" );
53
+ BENCHMARK_CAPTURE (bench_epsilon_decay, epsilon_decay_1_model_small_tol, true , " 1" , " 18" , " 1e-6" );
54
+ BENCHMARK_CAPTURE (bench_epsilon_decay, epsilon_decay_2_model_small_tol, true , " 2" , " 19" , " 1e-6" );
55
+ BENCHMARK_CAPTURE (bench_epsilon_decay, epsilon_decay_4_model_small_tol, true , " 4" , " 20" , " 1e-6" );
56
+ BENCHMARK_CAPTURE (bench_epsilon_decay, without_epsilon_decay, false , " " , " " , " " );
0 commit comments