From f8d43bc9e87ad5f8e72301a8e797bf5f881e63f4 Mon Sep 17 00:00:00 2001 From: Jack Gerrits Date: Tue, 20 Dec 2022 15:40:57 -0500 Subject: [PATCH] refactor: migrate bs finish_example (#4366) --- vowpalwabbit/core/src/reductions/bs.cc | 63 ++++++++++---------------- 1 file changed, 24 insertions(+), 39 deletions(-) diff --git a/vowpalwabbit/core/src/reductions/bs.cc b/vowpalwabbit/core/src/reductions/bs.cc index 2bf7f4d5fb5..5636c332c89 100644 --- a/vowpalwabbit/core/src/reductions/bs.cc +++ b/vowpalwabbit/core/src/reductions/bs.cc @@ -15,6 +15,7 @@ #include "vw/io/errno_handling.h" #include "vw/io/logger.h" +#include #include #include #include @@ -27,6 +28,8 @@ using namespace VW::LEARNER; using namespace VW::config; using namespace VW::reductions; +namespace +{ #define BS_TYPE_MEAN 0 #define BS_TYPE_VOTE 1 @@ -35,27 +38,26 @@ class bs_data public: uint32_t num_bootstrap_rounds = 0; // number of bootstrap rounds size_t bs_type = 0; - float lb = 0.f; - float ub = 0.f; std::vector pred_vec; VW::workspace* all = nullptr; // for raw prediction and loss std::shared_ptr random_state; }; -void bs_predict_mean(VW::workspace& all, VW::example& ec, std::vector& pred_vec) +void bs_predict_mean(const VW::workspace& all, VW::example& ec, const std::vector& pred_vec) { - ec.pred.scalar = static_cast(accumulate(pred_vec.cbegin(), pred_vec.cend(), 0.0)) / pred_vec.size(); + ec.pred.scalar = static_cast(std::accumulate(pred_vec.cbegin(), pred_vec.cend(), 0.0)) / pred_vec.size(); if (ec.weight > 0 && ec.l.simple.label != FLT_MAX) { ec.loss = all.loss->get_loss(all.sd, ec.pred.scalar, ec.l.simple.label) * ec.weight; } } -void bs_predict_vote(VW::example& ec, std::vector& pred_vec) +void bs_predict_vote(VW::example& ec, const std::vector& pred_vec) { // majority vote in linear time unsigned int counter = 0; - int current_label = 1, init_label = 1; + int current_label = 1; + int init_label = 1; // float sum_labels = 0; // uncomment for: "avg on votes" and get_loss() bool majority_found = false; bool multivote_detected = false; // distinct(votes)>2: used to skip part of the algorithm @@ -63,8 +65,8 @@ void bs_predict_vote(VW::example& ec, std::vector& pred_vec) for (size_t i = 0; i < pred_vec_int.size(); i++) { - pred_vec_int[i] = static_cast( - floor(pred_vec[i] + 0.5)); // could be added: link(), min_label/max_label, cutoff between true/false for binary + pred_vec_int[i] = static_cast(std::floor( + pred_vec[i] + 0.5)); // could be added: link(), min_label/max_label, cutoff between true/false for binary if (!multivote_detected) // distinct(votes)>2 detection bloc { @@ -138,45 +140,33 @@ void bs_predict_vote(VW::example& ec, std::vector& pred_vec) ec.loss = ((ec.pred.scalar == ec.l.simple.label) ? 0.f : 1.f) * ec.weight; } -void print_result( - VW::io::writer* f, float res, const VW::v_array& tag, float lb, float ub, VW::io::logger& logger) +void print_result(VW::io::writer* f, float res, const VW::v_array& tag, float lower_bound, float upper_bound, + VW::io::logger& logger) { if (f == nullptr) { return; } std::stringstream ss; ss << std::fixed << res; if (!tag.empty()) { ss << " " << VW::string_view{tag.begin(), tag.size()}; } - ss << std::fixed << ' ' << lb << ' ' << ub << '\n'; + ss << std::fixed << ' ' << lower_bound << ' ' << upper_bound << '\n'; const auto ss_str = ss.str(); ssize_t len = ss_str.size(); ssize_t t = f->write(ss_str.c_str(), static_cast(len)); if (t != len) { logger.err_error("write error: {}", VW::io::strerror_to_string(errno)); } } -void output_example(VW::workspace& all, bs_data& d, const VW::example& ec) +void output_example_prediction_bs( + VW::workspace& all, const bs_data& data, const VW::example& ec, VW::io::logger& logger) { - const auto& ld = ec.l.simple; - - all.sd->update(ec.test_only, ld.label != FLT_MAX, ec.loss, ec.weight, ec.get_num_features()); - if (ld.label != FLT_MAX && !ec.test_only) { all.sd->weighted_labels += (static_cast(ld.label)) * ec.weight; } - - if (!all.final_prediction_sink.empty()) // get confidence interval only when printing out predictions + if (!all.final_prediction_sink.empty()) { - d.lb = FLT_MAX; - d.ub = -FLT_MAX; - for (double v : d.pred_vec) + // get confidence interval only when printing out predictions + const auto min_max = std::minmax_element(data.pred_vec.begin(), data.pred_vec.end()); + for (auto& sink : all.final_prediction_sink) { - if (v > d.ub) { d.ub = static_cast(v); } - if (v < d.lb) { d.lb = static_cast(v); } + print_result(sink.get(), ec.pred.scalar, ec.tag, *min_max.first, *min_max.second, logger); } } - - for (auto& sink : all.final_prediction_sink) - { - print_result(sink.get(), ec.pred.scalar, ec.tag, d.lb, d.ub, all.logger); - } - - VW::details::print_update(all, ec); } template @@ -225,12 +215,7 @@ void predict_or_learn(bs_data& d, single_learner& base, VW::example& ec) all.print_text_by_ref(all.raw_prediction.get(), output_string_stream.str(), ec.tag, all.logger); } } - -void finish_example(VW::workspace& all, bs_data& d, VW::example& ec) -{ - output_example(all, d, ec); - VW::finish_example(all, ec); -} +} // namespace base_learner* VW::reductions::bs_setup(VW::setup_base_i& stack_builder) { @@ -252,8 +237,6 @@ base_learner* VW::reductions::bs_setup(VW::setup_base_i& stack_builder) if (!options.add_parse_and_check_necessary(new_options)) { return nullptr; } size_t ws = data->num_bootstrap_rounds; - data->ub = FLT_MAX; - data->lb = -FLT_MAX; if (options.was_supplied("bs_type")) { @@ -278,7 +261,9 @@ base_learner* VW::reductions::bs_setup(VW::setup_base_i& stack_builder) predict_or_learn, predict_or_learn, stack_builder.get_setupfn_name(bs_setup)) .set_params_per_weight(ws) .set_learn_returns_prediction(true) - .set_finish_example(::finish_example) + .set_output_example_prediction(output_example_prediction_bs) + .set_update_stats(VW::details::update_stats_simple_label) + .set_print_update(VW::details::print_update_simple_label) .set_input_label_type(VW::label_type_t::SIMPLE) .set_output_prediction_type(VW::prediction_type_t::SCALAR) .build();