Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

refactor: cb_algs finish functions #4409

Merged
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions vowpalwabbit/core/include/vw/core/reductions/cb/cb_algs.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,4 @@ inline bool example_is_newline_not_header(VW::example const& ec)
return (VW::example_is_newline(ec) && !CB::ec_is_example_header(ec));
}

// Shared example-output routine for contextual bandit reductions.
// Using the loss supplied by the caller, the definition updates all.sd, writes ec.pred.multiclass to every
// final prediction sink, writes space-separated "action:partial_prediction" pairs for ld.costs to the raw
// prediction sink when one is set, and finally calls print_update.
// known_cost is forwarded to print_update only when ld is not a test label; nullptr is passed otherwise.
void generic_output_example(
VW::workspace& all, float loss, const VW::example& ec, const CB::label& ld, const CB::cb_class* known_cost);

} // namespace CB_ALGS
96 changes: 47 additions & 49 deletions vowpalwabbit/core/src/reductions/cb/cb_algs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,42 +20,14 @@ using namespace VW::config;
using namespace CB;
using namespace GEN_CS;

namespace CB_ALGS
{
// Reports one finished CB example: records the caller-supplied loss in the shared stats, writes the
// predicted action to every final prediction sink, optionally writes the per-action partial predictions
// to the raw prediction sink, and then prints the progress update.
// known_cost is only handed to print_update for labeled (non-test) examples.
void generic_output_example(
VW::workspace& all, float loss, const VW::example& ec, const CB::label& ld, const CB::cb_class* known_cost)
{
  const bool test_label = ld.is_test_label();

  all.sd->update(ec.test_only, !test_label, loss, 1.f, ec.get_num_features());

  const auto predicted_action = static_cast<float>(ec.pred.multiclass);
  for (auto& sink : all.final_prediction_sink)
  {
    all.print_by_ref(sink.get(), predicted_action, 0, ec.tag, all.logger);
  }

  if (all.raw_prediction != nullptr)
  {
    // Builds "action:partial_prediction" entries separated by single spaces.
    std::stringstream raw_line;
    bool first_entry = true;
    for (const auto& cost : ld.costs)
    {
      if (!first_entry) { raw_line << ' '; }
      first_entry = false;
      raw_line << cost.action << ':' << cost.partial_prediction;
    }
    all.print_text_by_ref(all.raw_prediction.get(), raw_line.str(), ec.tag, all.logger);
  }

  // Test labels carry no observed cost, so nothing is passed along for them.
  print_update(all, test_label, ec, nullptr, false, test_label ? nullptr : known_cost);
}
}  // namespace CB_ALGS
namespace
{
// Per-learner state for the cb_algs reduction.
class cb
{
public:
  // Conversion state; its known_cost and pred_scores members are read when the loss estimate is computed.
  cb_to_cs cbcs;
  // Logger handed in at construction time.
  VW::io::logger logger;
  // Bound to the "eval" command line option; selects between ec.l.cb_eval.event and ec.l.cb as the label.
  bool eval = false;

  // explicit: a logger should never silently convert into reduction state.
  explicit cb(VW::io::logger logger) : logger(std::move(logger)) {}
};
Expand Down Expand Up @@ -109,26 +81,50 @@ void learn_eval(cb& data, single_learner&, VW::example& ec)
ec.pred.multiclass = ec.l.cb_eval.action;
}

void output_example(VW::workspace& all, cb& data, const VW::example& ec, const CB::label& ld)
// Stats-update hook for the cb_algs reduction.
// Picks the label from ec.l.cb_eval.event when data.eval is set and from ec.l.cb otherwise. The loss stays 0
// for a test label; for any other label it is CB_ALGS::get_cost_estimate over the stored known_cost,
// pred_scores and the predicted action. The result is then passed to sd.update.
// The workspace and logger parameters are unnamed and unused.
// NOTE(review): this span comes from a diff view in which old and new lines are interleaved. The second
// declaration of `c` and the CB_ALGS::generic_output_example call look like lines removed from the previous
// output_example implementation — confirm against the merged file.
void update_stats_cb_algs(
const VW::workspace&, VW::shared_data& sd, const cb& data, const VW::example& ec, VW::io::logger&)
{
const auto& ld = data.eval ? ec.l.cb_eval.event : ec.l.cb;
const auto& c = data.cbcs;
float loss = 0.;

cb_to_cs& c = data.cbcs;
if (!ld.is_test_label()) { loss = CB_ALGS::get_cost_estimate(c.known_cost, c.pred_scores, ec.pred.multiclass); }

CB_ALGS::generic_output_example(all, loss, ec, ld, &c.known_cost);
sd.update(ec.test_only, !ld.is_test_label(), loss, 1.f, ec.get_num_features());
}

void finish_example(VW::workspace& all, cb& c, VW::example& ec)
// Prediction-output hook for the cb_algs reduction.
// Writes ec.pred.multiclass (cast to float) to every sink in all.final_prediction_sink. When
// all.raw_prediction is set, it also writes one line of space-separated "action:partial_prediction" pairs,
// one pair per entry of the label's costs. The label is ec.l.cb_eval.event when data.eval is set, else ec.l.cb.
// The `logger` parameter is not used by the body; all.logger is passed to the print calls instead.
// NOTE(review): this span comes from a diff view in which old and new lines are interleaved. The
// output_example / VW::finish_example calls at the top look like lines removed from the previous
// finish_example implementation — confirm against the merged file.
void output_example_prediction_cb_algs(
VW::workspace& all, const cb& data, const VW::example& ec, VW::io::logger& logger)
{
output_example(all, c, ec, ec.l.cb);
VW::finish_example(all, ec);
const auto& ld = data.eval ? ec.l.cb_eval.event : ec.l.cb;

for (auto& sink : all.final_prediction_sink)
{
all.print_by_ref(sink.get(), static_cast<float>(ec.pred.multiclass), 0, ec.tag, all.logger);
}

if (all.raw_prediction != nullptr)
{
std::stringstream output_string_stream;
for (unsigned int i = 0; i < ld.costs.size(); i++)
{
cb_class cl = ld.costs[i];
// Separate entries with a single space; no leading or trailing separator.
if (i > 0) { output_string_stream << ' '; }
output_string_stream << cl.action << ':' << cl.partial_prediction;
}
all.print_text_by_ref(all.raw_prediction.get(), output_string_stream.str(), ec.tag, all.logger);
}
}

void eval_finish_example(VW::workspace& all, cb& c, VW::example& ec)
// Progress-print hook for the cb_algs reduction.
// Selects the label (ec.l.cb_eval.event when data.eval is set, else ec.l.cb) and calls print_update.
// The stored known_cost is passed only when the label is not a test label; nullptr is passed otherwise.
// The `sd` parameter is not used by the body, and the logger parameter is unnamed.
// NOTE(review): this span comes from a diff view in which old and new lines are interleaved. The
// output_example / VW::finish_example calls at the top look like lines removed from the previous
// eval_finish_example implementation — confirm against the merged file.
void print_update_cb_algs(
VW::workspace& all, VW::shared_data& sd, const cb& data, const VW::example& ec, VW::io::logger&)
{
output_example(all, c, ec, ec.l.cb_eval.event);
VW::finish_example(all, ec);
const auto& ld = data.eval ? ec.l.cb_eval.event : ec.l.cb;
const auto& c = data.cbcs;

bool is_ld_test_label = ld.is_test_label();
if (!is_ld_test_label) { print_update(all, is_ld_test_label, ec, nullptr, false, &c.known_cost); }
else { print_update(all, is_ld_test_label, ec, nullptr, false, nullptr); }
}
} // namespace

Expand All @@ -138,7 +134,6 @@ base_learner* VW::reductions::cb_algs_setup(VW::setup_base_i& stack_builder)
VW::workspace& all = *stack_builder.get_all_pointer();
auto data = VW::make_unique<cb>(all.logger);
std::string type_string = "dr";
bool eval = false;
bool force_legacy = true;

option_group_definition new_options("[Reduction] Contextual Bandit");
Expand All @@ -152,14 +147,14 @@ base_learner* VW::reductions::cb_algs_setup(VW::setup_base_i& stack_builder)
.default_value("dr")
.one_of({"ips", "dm", "dr", "mtr", "sm"})
.help("Contextual bandit method to use"))
.add(make_option("eval", eval).help("Evaluate a policy rather than optimizing"))
.add(make_option("eval", data->eval).help("Evaluate a policy rather than optimizing"))
.add(make_option("cb_force_legacy", force_legacy)
.keep()
.help("Default to non-adf cb implementation (cb_to_cb_adf)"));

if (!options.add_parse_and_check_necessary(new_options)) { return nullptr; }

if (!eval && !force_legacy) { return nullptr; }
if (!data->eval && !force_legacy) { return nullptr; }

// Ensure serialization of this option in all cases.
if (!options.was_supplied("cb_type"))
Expand All @@ -177,7 +172,7 @@ base_learner* VW::reductions::cb_algs_setup(VW::setup_base_i& stack_builder)
case VW::cb_type_t::DR:
break;
case VW::cb_type_t::DM:
if (eval) THROW("direct method can not be used for evaluation --- it is biased.");
if (data->eval) THROW("direct method can not be used for evaluation --- it is biased.");
problem_multiplier = 1;
break;
case VW::cb_type_t::IPS:
Expand All @@ -199,15 +194,16 @@ base_learner* VW::reductions::cb_algs_setup(VW::setup_base_i& stack_builder)
}

auto base = as_singleline(stack_builder.setup_base_learner());
if (eval) { all.example_parser->lbl_parser = CB_EVAL::cb_eval; }
if (data->eval) { all.example_parser->lbl_parser = CB_EVAL::cb_eval; }
else { all.example_parser->lbl_parser = CB::cb_label; }
c.scorer = VW::LEARNER::as_singleline(base->get_learner_by_name_prefix("scorer"));

std::string name_addition = eval ? "-eval" : "";
auto learn_ptr = eval ? learn_eval : predict_or_learn<true>;
auto predict_ptr = eval ? predict_eval : predict_or_learn<false>;
auto label_type = eval ? VW::label_type_t::CB_EVAL : VW::label_type_t::CB;
auto finish_ex = eval ? eval_finish_example : ::finish_example;
std::string name_addition = data->eval ? "-eval" : "";
auto learn_ptr = data->eval ? learn_eval : predict_or_learn<true>;
auto predict_ptr = data->eval ? predict_eval : predict_or_learn<false>;
auto label_type = data->eval ? VW::label_type_t::CB_EVAL : VW::label_type_t::CB;
// needed because we move data into the learner, but still need to set a value based on eval
auto eval = data->eval;

auto* l = make_reduction_learner(
std::move(data), base, learn_ptr, predict_ptr, stack_builder.get_setupfn_name(cb_algs_setup) + name_addition)
Expand All @@ -217,7 +213,9 @@ base_learner* VW::reductions::cb_algs_setup(VW::setup_base_i& stack_builder)
.set_output_prediction_type(VW::prediction_type_t::MULTICLASS)
.set_params_per_weight(problem_multiplier)
.set_learn_returns_prediction(eval)
.set_finish_example(finish_ex)
.set_update_stats(::update_stats_cb_algs)
.set_output_example_prediction(::output_example_prediction_cb_algs)
.set_print_update(::print_update_cb_algs)
.build(&all.logger);

return make_base(*l);
Expand Down