Skip to content

Commit

Permalink
refactor: allow reduction to control print frequency (#4315)
Browse files Browse the repository at this point in the history
* refactor: allow reduction to control print frequency

* add check for slates
  • Loading branch information
jackgerrits authored Dec 6, 2022
1 parent 49ad32a commit adfaf02
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 7 deletions.
4 changes: 1 addition & 3 deletions vowpalwabbit/core/include/vw/core/learner.h
Original file line number Diff line number Diff line change
Expand Up @@ -423,9 +423,7 @@ class learner

if (has_output_example_prediction()) { output_example_prediction(all, ec); }

const bool should_print_driver_update =
all.sd->weighted_examples() >= all.sd->dump_interval && !all.quiet && !all.bfgs;
if (has_print_update() && should_print_driver_update) { print_update(all, ec); }
if (has_print_update()) { print_update(all, ec); }

if (has_cleanup_example()) { cleanup_example(ec); }

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -580,7 +580,10 @@ void output_example_prediction_ccb(
void print_update_ccb(VW::workspace& all, shared_data& /* sd */, const ccb_data& data, const VW::multi_ex& ec_seq,
VW::io::logger& /* unused */)
{
if (!ec_seq.empty() && !data.no_pred)
const bool should_print_driver_update =
all.sd->weighted_examples() >= all.sd->dump_interval && !all.quiet && !all.bfgs;

if (should_print_driver_update && !ec_seq.empty() && !data.no_pred)
{
// Print progress
size_t num_features = 0;
Expand Down
5 changes: 4 additions & 1 deletion vowpalwabbit/core/src/reductions/mwt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,10 @@ void output_example_prediction_mwt(
void print_update_mwt(
VW::workspace& all, shared_data& /* sd */, const mwt& data, const VW::example& ec, VW::io::logger& /* unused */)
{
if (data.learn)
const bool should_print_driver_update =
all.sd->weighted_examples() >= all.sd->dump_interval && !all.quiet && !all.bfgs;

if (should_print_driver_update && data.learn)
{
size_t num_features = ec.get_num_features();
size_t pred = ec.pred.multiclass;
Expand Down
5 changes: 5 additions & 0 deletions vowpalwabbit/core/src/reductions/slates.cc
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,11 @@ void output_example_prediction_slates(VW::workspace& all, const VW::reductions::
void print_update_slates(VW::workspace& all, shared_data& /* sd */, const VW::reductions::slates_data& /* data */,
const VW::multi_ex& ec_seq, VW::io::logger& /* unused */)
{
const bool should_print_driver_update =
all.sd->weighted_examples() >= all.sd->dump_interval && !all.quiet && !all.bfgs;

if (!should_print_driver_update) { return; }

const auto& predictions = ec_seq[0]->pred.decision_scores;
VW::multi_ex slots;
size_t num_features = 0;
Expand Down
10 changes: 8 additions & 2 deletions vowpalwabbit/core/src/simple_label.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,14 @@ void VW::details::update_stats_simple_label(
void VW::details::print_update_simple_label(
VW::workspace& all, shared_data& sd, const VW::example& ec, VW::io::logger& /* logger */)
{
sd.print_update(*all.trace_message, all.holdout_set_off, all.current_pass, ec.l.simple.label, ec.pred.scalar,
ec.get_num_features(), all.progress_add, all.progress_arg);
const bool should_print_driver_update =
all.sd->weighted_examples() >= all.sd->dump_interval && !all.quiet && !all.bfgs;

if (should_print_driver_update)
{
sd.print_update(*all.trace_message, all.holdout_set_off, all.current_pass, ec.l.simple.label, ec.pred.scalar,
ec.get_num_features(), all.progress_add, all.progress_arg);
}
}

void VW::details::output_example_prediction_simple_label(
Expand Down

0 comments on commit adfaf02

Please sign in to comment.