Commit
refactor: migrate bs finish_example (#4366)
jackgerrits authored Dec 20, 2022
1 parent b6f6c15 commit f8d43bc
Showing 1 changed file with 24 additions and 39 deletions.
63 changes: 24 additions & 39 deletions vowpalwabbit/core/src/reductions/bs.cc
@@ -15,6 +15,7 @@
 #include "vw/io/errno_handling.h"
 #include "vw/io/logger.h"
 
+#include <algorithm>
 #include <cerrno>
 #include <cfloat>
 #include <cmath>
@@ -27,6 +28,8 @@ using namespace VW::LEARNER;
 using namespace VW::config;
 using namespace VW::reductions;
 
+namespace
+{
 #define BS_TYPE_MEAN 0
 #define BS_TYPE_VOTE 1
 
@@ -35,36 +38,35 @@ class bs_data
 public:
   uint32_t num_bootstrap_rounds = 0; // number of bootstrap rounds
   size_t bs_type = 0;
-  float lb = 0.f;
-  float ub = 0.f;
   std::vector<double> pred_vec;
   VW::workspace* all = nullptr; // for raw prediction and loss
   std::shared_ptr<VW::rand_state> random_state;
 };
 
-void bs_predict_mean(VW::workspace& all, VW::example& ec, std::vector<double>& pred_vec)
+void bs_predict_mean(const VW::workspace& all, VW::example& ec, const std::vector<double>& pred_vec)
 {
-  ec.pred.scalar = static_cast<float>(accumulate(pred_vec.cbegin(), pred_vec.cend(), 0.0)) / pred_vec.size();
+  ec.pred.scalar = static_cast<float>(std::accumulate(pred_vec.cbegin(), pred_vec.cend(), 0.0)) / pred_vec.size();
   if (ec.weight > 0 && ec.l.simple.label != FLT_MAX)
   {
     ec.loss = all.loss->get_loss(all.sd, ec.pred.scalar, ec.l.simple.label) * ec.weight;
   }
 }
 
-void bs_predict_vote(VW::example& ec, std::vector<double>& pred_vec)
+void bs_predict_vote(VW::example& ec, const std::vector<double>& pred_vec)
 {
   // majority vote in linear time
   unsigned int counter = 0;
-  int current_label = 1, init_label = 1;
+  int current_label = 1;
+  int init_label = 1;
   // float sum_labels = 0; // uncomment for: "avg on votes" and get_loss()
   bool majority_found = false;
   bool multivote_detected = false; // distinct(votes)>2: used to skip part of the algorithm
   std::vector<int> pred_vec_int(pred_vec.size(), 0);
 
   for (size_t i = 0; i < pred_vec_int.size(); i++)
   {
-    pred_vec_int[i] = static_cast<int>(
-        floor(pred_vec[i] + 0.5)); // could be added: link(), min_label/max_label, cutoff between true/false for binary
+    pred_vec_int[i] = static_cast<int>(std::floor(
+        pred_vec[i] + 0.5)); // could be added: link(), min_label/max_label, cutoff between true/false for binary
 
     if (!multivote_detected) // distinct(votes)>2 detection bloc
     {
@@ -138,45 +140,33 @@ void bs_predict_vote(VW::example& ec, std::vector<double>& pred_vec)
   ec.loss = ((ec.pred.scalar == ec.l.simple.label) ? 0.f : 1.f) * ec.weight;
 }
 
-void print_result(
-    VW::io::writer* f, float res, const VW::v_array<char>& tag, float lb, float ub, VW::io::logger& logger)
+void print_result(VW::io::writer* f, float res, const VW::v_array<char>& tag, float lower_bound, float upper_bound,
+    VW::io::logger& logger)
 {
   if (f == nullptr) { return; }
 
   std::stringstream ss;
   ss << std::fixed << res;
   if (!tag.empty()) { ss << " " << VW::string_view{tag.begin(), tag.size()}; }
-  ss << std::fixed << ' ' << lb << ' ' << ub << '\n';
+  ss << std::fixed << ' ' << lower_bound << ' ' << upper_bound << '\n';
   const auto ss_str = ss.str();
   ssize_t len = ss_str.size();
   ssize_t t = f->write(ss_str.c_str(), static_cast<unsigned int>(len));
   if (t != len) { logger.err_error("write error: {}", VW::io::strerror_to_string(errno)); }
 }
 
-void output_example(VW::workspace& all, bs_data& d, const VW::example& ec)
+void output_example_prediction_bs(
+    VW::workspace& all, const bs_data& data, const VW::example& ec, VW::io::logger& logger)
 {
-  const auto& ld = ec.l.simple;
-
-  all.sd->update(ec.test_only, ld.label != FLT_MAX, ec.loss, ec.weight, ec.get_num_features());
-  if (ld.label != FLT_MAX && !ec.test_only) { all.sd->weighted_labels += (static_cast<double>(ld.label)) * ec.weight; }
-
-  if (!all.final_prediction_sink.empty()) // get confidence interval only when printing out predictions
+  if (!all.final_prediction_sink.empty())
   {
-    d.lb = FLT_MAX;
-    d.ub = -FLT_MAX;
-    for (double v : d.pred_vec)
+    // get confidence interval only when printing out predictions
+    const auto min_max = std::minmax_element(data.pred_vec.begin(), data.pred_vec.end());
+    for (auto& sink : all.final_prediction_sink)
     {
-      if (v > d.ub) { d.ub = static_cast<float>(v); }
-      if (v < d.lb) { d.lb = static_cast<float>(v); }
+      print_result(sink.get(), ec.pred.scalar, ec.tag, *min_max.first, *min_max.second, logger);
     }
   }
-
-  for (auto& sink : all.final_prediction_sink)
-  {
-    print_result(sink.get(), ec.pred.scalar, ec.tag, d.lb, d.ub, all.logger);
-  }
-
-  VW::details::print_update(all, ec);
 }
 
 template <bool is_learn>
@@ -225,12 +215,7 @@ void predict_or_learn(bs_data& d, single_learner& base, VW::example& ec)
     all.print_text_by_ref(all.raw_prediction.get(), output_string_stream.str(), ec.tag, all.logger);
   }
 }
-
-void finish_example(VW::workspace& all, bs_data& d, VW::example& ec)
-{
-  output_example(all, d, ec);
-  VW::finish_example(all, ec);
-}
+} // namespace
 
 base_learner* VW::reductions::bs_setup(VW::setup_base_i& stack_builder)
 {
@@ -252,8 +237,6 @@ base_learner* VW::reductions::bs_setup(VW::setup_base_i& stack_builder)
 
   if (!options.add_parse_and_check_necessary(new_options)) { return nullptr; }
   size_t ws = data->num_bootstrap_rounds;
-  data->ub = FLT_MAX;
-  data->lb = -FLT_MAX;
 
   if (options.was_supplied("bs_type"))
   {
@@ -278,7 +261,9 @@ base_learner* VW::reductions::bs_setup(VW::setup_base_i& stack_builder)
       predict_or_learn<true>, predict_or_learn<false>, stack_builder.get_setupfn_name(bs_setup))
           .set_params_per_weight(ws)
           .set_learn_returns_prediction(true)
-          .set_finish_example(::finish_example)
+          .set_output_example_prediction(output_example_prediction_bs)
+          .set_update_stats(VW::details::update_stats_simple_label<bs_data>)
+          .set_print_update(VW::details::print_update_simple_label<bs_data>)
           .set_input_label_type(VW::label_type_t::SIMPLE)
           .set_output_prediction_type(VW::prediction_type_t::SCALAR)
           .build();
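
Note: the bs_predict_vote hunks above keep the original "majority vote in linear time" pass. The snippet below is only a minimal Boyer-Moore-style sketch of that idea (the function name majority_vote_sketch is made up for illustration); the real reduction additionally detects more than two distinct votes and resolves ties, as shown in the diff.

#include <cmath>
#include <vector>

// Illustrative only: generic linear-time majority vote over rounded
// bootstrap predictions. Not VW's exact bs_predict_vote.
int majority_vote_sketch(const std::vector<double>& pred_vec)
{
  int candidate = 0;
  unsigned int counter = 0;
  for (double p : pred_vec)
  {
    const int label = static_cast<int>(std::floor(p + 0.5)); // round to the nearest integer label
    if (counter == 0)
    {
      candidate = label;
      counter = 1;
    }
    else if (label == candidate) { counter++; }
    else { counter--; }
  }
  // If no strict majority exists, candidate is only a plurality guess;
  // the real function falls back to other logic in that case.
  return candidate;
}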
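
Note: the core of this commit is replacing the reduction's single finish_example hook with the three narrower hooks registered at the bottom of the diff (set_update_stats, set_output_example_prediction, set_print_update), and letting the framework finish the example itself. The sketch below illustrates that decomposition with hypothetical stand-in types; workspace, example, hooks, and drive_one_example here are not VW types, and the call order shown is an assumption rather than something taken from this diff.

#include <functional>
#include <iostream>

// Hypothetical stand-ins; NOT the VW types or hook signatures.
struct workspace { double sum_loss = 0.0; };
struct example { double prediction = 0.0; double loss = 0.0; };

struct hooks
{
  std::function<void(workspace&, const example&)> update_stats;              // loss/statistics bookkeeping
  std::function<void(workspace&, const example&)> output_example_prediction; // write predictions to the sinks
  std::function<void(workspace&, const example&)> print_update;              // periodic progress line
};

// Before this commit, one finish_example callback did all three jobs and then
// called VW::finish_example itself; afterwards the driver presumably invokes
// the focused hooks and returns the example on its own.
void drive_one_example(workspace& all, example& ec, const hooks& h)
{
  if (h.update_stats) { h.update_stats(all, ec); }
  if (h.output_example_prediction) { h.output_example_prediction(all, ec); }
  if (h.print_update) { h.print_update(all, ec); }
  // ...framework-side cleanup would happen here, not in the reduction.
}

int main()
{
  workspace all;
  example ec;
  ec.prediction = 0.5;
  ec.loss = 0.25;
  hooks h;
  h.update_stats = [](workspace& w, const example& e) { w.sum_loss += e.loss; };
  h.output_example_prediction = [](workspace&, const example& e) { std::cout << e.prediction << '\n'; };
  h.print_update = [](workspace& w, const example&) { std::cerr << "total loss so far: " << w.sum_loss << '\n'; };
  drive_one_example(all, ec, h);
  return 0;
}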

1 comment on commit f8d43bc

@github-actions

⚠️ Performance Alert ⚠️

A possible performance regression was detected for a benchmark.
The benchmark result of this commit is worse than the previous benchmark result, exceeding the alert threshold of 2.

Benchmark suite: BenchmarkLearnSimple.Benchmark(args: 1_feature)
Current (f8d43bc): 8169.662921348315 ns (± 1577.6424078050777)
Previous (b6f6c15): 779.5062798720139 ns (± 1.2837945626837122)
Ratio: 10.48
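
For reference, the reported ratio appears to be the current mean time divided by the previous one: 8169.66 ns / 779.51 ns ≈ 10.48, well above the alert threshold of 2.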

This comment was automatically generated by a workflow using github-action-benchmark.
