Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: cb_algs finish functions #4409

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions vowpalwabbit/core/include/vw/core/reductions/cb/cb_algs.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,4 @@ inline bool example_is_newline_not_header(VW::example const& ec)
return (VW::example_is_newline(ec) && !CB::ec_is_example_header(ec));
}

void generic_output_example(
VW::workspace& all, float loss, const VW::example& ec, const CB::label& ld, const CB::cb_class* known_cost);

} // namespace CB_ALGS
86 changes: 46 additions & 40 deletions vowpalwabbit/core/src/reductions/cb/cb_algs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,42 +13,14 @@
#include "vw/io/logger.h"

#include <cfloat>
#include <sstream>

using namespace VW::LEARNER;
using namespace VW::config;

using namespace CB;
using namespace GEN_CS;

namespace CB_ALGS
{
void generic_output_example(
VW::workspace& all, float loss, const VW::example& ec, const CB::label& ld, const CB::cb_class* known_cost)
{
all.sd->update(ec.test_only, !ld.is_test_label(), loss, 1.f, ec.get_num_features());

for (auto& sink : all.final_prediction_sink)
{
all.print_by_ref(sink.get(), static_cast<float>(ec.pred.multiclass), 0, ec.tag, all.logger);
}

if (all.raw_prediction != nullptr)
{
std::stringstream output_string_stream;
for (unsigned int i = 0; i < ld.costs.size(); i++)
{
cb_class cl = ld.costs[i];
if (i > 0) { output_string_stream << ' '; }
output_string_stream << cl.action << ':' << cl.partial_prediction;
}
all.print_text_by_ref(all.raw_prediction.get(), output_string_stream.str(), ec.tag, all.logger);
}

bool is_ld_test_label = ld.is_test_label();
if (!is_ld_test_label) { print_update(all, is_ld_test_label, ec, nullptr, false, known_cost); }
else { print_update(all, is_ld_test_label, ec, nullptr, false, nullptr); }
}
} // namespace CB_ALGS
namespace
{
class cb
Expand Down Expand Up @@ -109,26 +81,53 @@ void learn_eval(cb& data, single_learner&, VW::example& ec)
ec.pred.multiclass = ec.l.cb_eval.action;
}

void output_example(VW::workspace& all, cb& data, const VW::example& ec, const CB::label& ld)
template <bool uses_eval>
void update_stats_cb_algs(const VW::workspace& /* all */, VW::shared_data& sd, const cb& data, const VW::example& ec,
VW::io::logger& /* unused */)
{
const auto& ld = uses_eval ? ec.l.cb_eval.event : ec.l.cb;
const auto& c = data.cbcs;
float loss = 0.;

cb_to_cs& c = data.cbcs;
if (!ld.is_test_label()) { loss = CB_ALGS::get_cost_estimate(c.known_cost, c.pred_scores, ec.pred.multiclass); }

CB_ALGS::generic_output_example(all, loss, ec, ld, &c.known_cost);
sd.update(ec.test_only, !ld.is_test_label(), loss, 1.f, ec.get_num_features());
}

void finish_example(VW::workspace& all, cb& c, VW::example& ec)
template <bool uses_eval>
void output_example_prediction_cb_algs(
VW::workspace& all, const cb& /* data */, const VW::example& ec, VW::io::logger& logger)
{
output_example(all, c, ec, ec.l.cb);
VW::finish_example(all, ec);
const auto& ld = uses_eval ? ec.l.cb_eval.event : ec.l.cb;

for (auto& sink : all.final_prediction_sink)
{
all.print_by_ref(sink.get(), static_cast<float>(ec.pred.multiclass), 0, ec.tag, all.logger);
}

if (all.raw_prediction != nullptr)
{
std::stringstream output_string_stream;
for (unsigned int i = 0; i < ld.costs.size(); i++)
{
cb_class cl = ld.costs[i];
if (i > 0) { output_string_stream << ' '; }
output_string_stream << cl.action << ':' << cl.partial_prediction;
}
all.print_text_by_ref(all.raw_prediction.get(), output_string_stream.str(), ec.tag, logger);
}
}

void eval_finish_example(VW::workspace& all, cb& c, VW::example& ec)
template <bool uses_eval>
void print_update_cb_algs(
VW::workspace& all, VW::shared_data& /* sd */, const cb& data, const VW::example& ec, VW::io::logger& /* unused */)
{
output_example(all, c, ec, ec.l.cb_eval.event);
VW::finish_example(all, ec);
const auto& ld = uses_eval ? ec.l.cb_eval.event : ec.l.cb;
const auto& c = data.cbcs;

bool is_ld_test_label = ld.is_test_label();
if (!is_ld_test_label) { print_update(all, is_ld_test_label, ec, nullptr, false, &c.known_cost); }
else { print_update(all, is_ld_test_label, ec, nullptr, false, nullptr); }
}
} // namespace

Expand Down Expand Up @@ -207,7 +206,12 @@ base_learner* VW::reductions::cb_algs_setup(VW::setup_base_i& stack_builder)
auto learn_ptr = eval ? learn_eval : predict_or_learn<true>;
auto predict_ptr = eval ? predict_eval : predict_or_learn<false>;
auto label_type = eval ? VW::label_type_t::CB_EVAL : VW::label_type_t::CB;
auto finish_ex = eval ? eval_finish_example : ::finish_example;
VW::learner_update_stats_func<cb, VW::example>* update_stats_func =
eval ? update_stats_cb_algs<true> : update_stats_cb_algs<false>;
VW::learner_output_example_prediction_func<cb, VW::example>* output_example_prediction_func =
eval ? output_example_prediction_cb_algs<true> : output_example_prediction_cb_algs<false>;
VW::learner_print_update_func<cb, VW::example>* print_update_func =
eval ? print_update_cb_algs<true> : print_update_cb_algs<false>;

auto* l = make_reduction_learner(
std::move(data), base, learn_ptr, predict_ptr, stack_builder.get_setupfn_name(cb_algs_setup) + name_addition)
Expand All @@ -217,7 +221,9 @@ base_learner* VW::reductions::cb_algs_setup(VW::setup_base_i& stack_builder)
.set_output_prediction_type(VW::prediction_type_t::MULTICLASS)
.set_params_per_weight(problem_multiplier)
.set_learn_returns_prediction(eval)
.set_finish_example(finish_ex)
.set_update_stats(update_stats_func)
.set_output_example_prediction(output_example_prediction_func)
.set_print_update(print_update_func)
.build(&all.logger);

return make_base(*l);
Expand Down