Skip to content

Commit

Permalink
refactor: Add metrics collector (#4407)
Browse files Browse the repository at this point in the history
* merge

* compile

* uncomment

* revert

* add manager

* more additions

* clang

* remove options parsing

* clang

* address comments

* comments

* clang

* remove output hooks

* unused

* comments

* clang

* update name to metrics_collector

* remove filename

* remove filename .h

* remove fn

* fix construction

* revert settings
  • Loading branch information
bassmang authored Jan 3, 2023
1 parent 72077e5 commit 9741c69
Show file tree
Hide file tree
Showing 23 changed files with 121 additions and 82 deletions.
5 changes: 2 additions & 3 deletions python/pylibvw.cc
Original file line number Diff line number Diff line change
Expand Up @@ -327,10 +327,9 @@ py::dict get_learner_metrics(vw_ptr all)
{
py::dict dictionary;

if (all->options->was_supplied("extra_metrics"))
if (all->global_metrics.are_metrics_enabled())
{
VW::metric_sink metrics;
all->l->persist_metrics(metrics);
auto metrics = all->global_metrics.collect_metrics(all->l);

python_dict_writer writer(dictionary);
metrics.visit(writer);
Expand Down
2 changes: 2 additions & 0 deletions vowpalwabbit/core/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ set(vw_core_headers
include/vw/core/loss_functions.h
include/vw/core/memory.h
include/vw/core/merge.h
include/vw/core/metrics_collector.h
include/vw/core/metric_sink.h
include/vw/core/model_utils.h
include/vw/core/multi_model_utils.h
Expand Down Expand Up @@ -228,6 +229,7 @@ set(vw_core_sources
src/learner.cc
src/loss_functions.cc
src/merge.cc
src/metrics_collector.cc
src/metric_sink.cc
src/multiclass.cc
src/multilabel.cc
Expand Down
5 changes: 2 additions & 3 deletions vowpalwabbit/core/include/vw/core/global_data.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "vw/core/error_reporting.h"
#include "vw/core/input_parser.h"
#include "vw/core/interaction_generation_state.h"
#include "vw/core/metrics_collector.h"
#include "vw/core/multi_ex.h"
#include "vw/core/version.h"
#include "vw/core/vw_fwd.h"
Expand Down Expand Up @@ -151,9 +152,7 @@ class workspace
std::unique_ptr<VW::parsers::flatbuffer::parser> flat_converter;
#endif

// This field is experimental and subject to change.
// Used to implement the external binary parser.
std::vector<std::function<void(VW::metric_sink&)>> metric_output_hooks;
VW::metrics_collector global_metrics;

// Experimental field.
// Generic parser interface to make it possible to use any external parser.
Expand Down
35 changes: 35 additions & 0 deletions vowpalwabbit/core/include/vw/core/metrics_collector.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// Copyright (c) by respective owners including Yahoo!, Microsoft, and
// individual contributors. All rights reserved. Released under a BSD (revised)
// license as described in the file LICENSE.=

#pragma once

#include "metric_sink.h"
#include "vw/core/learner_fwd.h"

#include <cstdint>
#include <functional>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <vector>

namespace VW
{
class metrics_collector
{
public:
metrics_collector(bool enabled = false);

using metrics_callback_fn = std::function<void(VW::metric_sink&)>;

bool are_metrics_enabled() const;
void register_metrics_callback(const metrics_callback_fn& callback);
VW::metric_sink collect_metrics(LEARNER::base_learner* l = nullptr) const;

private:
bool _are_metrics_enabled;
std::vector<metrics_callback_fn> _metrics_callbacks;
};
} // namespace VW
1 change: 1 addition & 0 deletions vowpalwabbit/core/include/vw/core/reductions/metrics.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,6 @@ namespace reductions
{
VW::LEARNER::base_learner* metrics_setup(VW::setup_base_i& stack_builder);
void output_metrics(VW::workspace& all);
void additional_metrics(VW::workspace& all, VW::metric_sink& sink);
} // namespace reductions
} // namespace VW
27 changes: 27 additions & 0 deletions vowpalwabbit/core/src/metrics_collector.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// Copyright (c) by respective owners including Yahoo!, Microsoft, and
// individual contributors. All rights reserved. Released under a BSD (revised)
// license as described in the file LICENSE.

#include "vw/core/metrics_collector.h"

#include "vw/core/learner.h"
namespace VW
{
metrics_collector::metrics_collector(bool enabled) : _are_metrics_enabled(enabled) {}

bool metrics_collector::are_metrics_enabled() const { return _are_metrics_enabled; }

void metrics_collector::register_metrics_callback(const metrics_callback_fn& callback)
{
if (_are_metrics_enabled) { _metrics_callbacks.push_back(callback); }
}

metric_sink metrics_collector::collect_metrics(LEARNER::base_learner* l) const
{
VW::metric_sink sink;
if (!_are_metrics_enabled) { THROW("Metrics must be enabled to call collect_metrics"); }
if (l) { l->persist_metrics(sink); }
for (const auto& callback : _metrics_callbacks) { callback(sink); }
return sink;
}
} // namespace VW
2 changes: 1 addition & 1 deletion vowpalwabbit/core/src/parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ void set_json_reader(VW::workspace& all, bool dsjson = false)

all.example_parser->decision_service_json = dsjson;

if (dsjson && all.options->was_supplied("extra_metrics"))
if (dsjson && all.global_metrics.are_metrics_enabled())
{
all.example_parser->metrics = VW::make_unique<VW::details::dsjson_metrics>();
}
Expand Down
11 changes: 3 additions & 8 deletions vowpalwabbit/core/src/reductions/baseline_challenger_cb.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,8 @@ class baseline_challenger_data
VW::estimators::ChiSquared baseline;
discounted_expectation policy_expectation;
float baseline_epsilon;
bool emit_metrics;

baseline_challenger_data(bool emit_metrics, double alpha, double tau)
: baseline(alpha, tau), policy_expectation(tau), emit_metrics(emit_metrics)
{
}
baseline_challenger_data(double alpha, double tau) : baseline(alpha, tau), policy_expectation(tau) {}

static int get_chosen_action(const VW::action_scores& action_scores) { return action_scores[0].action; }

Expand Down Expand Up @@ -176,7 +172,6 @@ void save_load(baseline_challenger_data& data, VW::io_buf& io, bool read, bool t

void persist_metrics(baseline_challenger_data& data, metric_sink& metrics)
{
if (!data.emit_metrics) { return; }
auto ci = static_cast<float>(data.baseline.lower_bound_and_update());
auto exp = static_cast<float>(data.policy_expectation.current());

Expand All @@ -188,6 +183,7 @@ void persist_metrics(baseline_challenger_data& data, metric_sink& metrics)
VW::LEARNER::base_learner* VW::reductions::baseline_challenger_cb_setup(VW::setup_base_i& stack_builder)
{
options_i& options = *stack_builder.get_options();

float alpha;
float tau;
bool is_enabled = false;
Expand Down Expand Up @@ -215,8 +211,7 @@ VW::LEARNER::base_learner* VW::reductions::baseline_challenger_cb_setup(VW::setu

if (!options.was_supplied("cb_adf")) { THROW("cb_challenger requires cb_explore_adf or cb_adf"); }

bool emit_metrics = options.was_supplied("extra_metrics");
auto data = VW::make_unique<baseline_challenger_data>(emit_metrics, alpha, tau);
auto data = VW::make_unique<baseline_challenger_data>(alpha, tau);

auto* l = make_reduction_learner(std::move(data), as_multiline(stack_builder.setup_base_learner()),
learn_or_predict<true>, learn_or_predict<false>, stack_builder.get_setupfn_name(baseline_challenger_cb_setup))
Expand Down
6 changes: 2 additions & 4 deletions vowpalwabbit/core/src/reductions/cb/cb_explore_adf_bag.cc
Original file line number Diff line number Diff line change
Expand Up @@ -206,11 +206,9 @@ VW::LEARNER::base_learner* VW::reductions::cb_explore_adf_bag_setup(VW::setup_ba
VW::LEARNER::multi_learner* base = as_multiline(stack_builder.setup_base_learner());
all.example_parser->lbl_parser = CB::cb_label;

bool with_metrics = options.was_supplied("extra_metrics");

using explore_type = cb_explore_adf_base<cb_explore_adf_bag>;
auto data = VW::make_unique<explore_type>(
with_metrics, epsilon, VW::cast_to_smaller_type<size_t>(bag_size), greedify, first_only, all.get_random_state());
auto data = VW::make_unique<explore_type>(all.global_metrics.are_metrics_enabled(), epsilon,
VW::cast_to_smaller_type<size_t>(bag_size), greedify, first_only, all.get_random_state());
auto* l = make_reduction_learner(std::move(data), base, explore_type::learn, explore_type::predict,
stack_builder.get_setupfn_name(cb_explore_adf_bag_setup))
.set_input_label_type(VW::label_type_t::CB)
Expand Down
8 changes: 3 additions & 5 deletions vowpalwabbit/core/src/reductions/cb/cb_explore_adf_cover.cc
Original file line number Diff line number Diff line change
Expand Up @@ -323,14 +323,12 @@ VW::LEARNER::base_learner* VW::reductions::cb_explore_adf_cover_setup(VW::setup_
epsilon_decay = true;
}

bool with_metrics = options.was_supplied("extra_metrics");

auto* scorer = VW::LEARNER::as_singleline(base->get_learner_by_name_prefix("scorer"));

using explore_type = cb_explore_adf_base<cb_explore_adf_cover>;
auto data =
VW::make_unique<explore_type>(with_metrics, VW::cast_to_smaller_type<size_t>(cover_size), psi, nounif, epsilon,
epsilon_decay, first_only, as_multiline(all.cost_sensitive), scorer, cb_type, all.model_file_ver, all.logger);
auto data = VW::make_unique<explore_type>(all.global_metrics.are_metrics_enabled(),
VW::cast_to_smaller_type<size_t>(cover_size), psi, nounif, epsilon, epsilon_decay, first_only,
as_multiline(all.cost_sensitive), scorer, cb_type, all.model_file_ver, all.logger);
auto* l = make_reduction_learner(std::move(data), base, explore_type::learn, explore_type::predict,
stack_builder.get_setupfn_name(cb_explore_adf_cover_setup))
.set_input_label_type(VW::label_type_t::CB)
Expand Down
6 changes: 2 additions & 4 deletions vowpalwabbit/core/src/reductions/cb/cb_explore_adf_first.cc
Original file line number Diff line number Diff line change
Expand Up @@ -122,11 +122,9 @@ VW::LEARNER::base_learner* VW::reductions::cb_explore_adf_first_setup(VW::setup_
multi_learner* base = as_multiline(stack_builder.setup_base_learner());
all.example_parser->lbl_parser = CB::cb_label;

bool with_metrics = options.was_supplied("extra_metrics");

using explore_type = cb_explore_adf_base<cb_explore_adf_first>;
auto data =
VW::make_unique<explore_type>(with_metrics, VW::cast_to_smaller_type<size_t>(tau), epsilon, all.model_file_ver);
auto data = VW::make_unique<explore_type>(
all.global_metrics.are_metrics_enabled(), VW::cast_to_smaller_type<size_t>(tau), epsilon, all.model_file_ver);

if (epsilon < 0.0 || epsilon > 1.0) { THROW("The value of epsilon must be in [0,1]"); }
auto* l = make_reduction_learner(std::move(data), base, explore_type::learn, explore_type::predict,
Expand Down
4 changes: 1 addition & 3 deletions vowpalwabbit/core/src/reductions/cb/cb_explore_adf_greedy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -133,10 +133,8 @@ VW::LEARNER::base_learner* VW::reductions::cb_explore_adf_greedy_setup(VW::setup
VW::LEARNER::multi_learner* base = as_multiline(stack_builder.setup_base_learner());
all.example_parser->lbl_parser = CB::cb_label;

bool with_metrics = options.was_supplied("extra_metrics");

using explore_type = cb_explore_adf_base<cb_explore_adf_greedy>;
auto data = VW::make_unique<explore_type>(with_metrics, epsilon, first_only);
auto data = VW::make_unique<explore_type>(all.global_metrics.are_metrics_enabled(), epsilon, first_only);

if (epsilon < 0.0 || epsilon > 1.0) { THROW("The value of epsilon must be in [0,1]"); }
auto* l = make_reduction_learner(std::move(data), base, explore_type::learn, explore_type::predict,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -406,20 +406,19 @@ VW::LEARNER::base_learner* VW::reductions::cb_explore_adf_large_action_space_set
VW::LEARNER::multi_learner* base = as_multiline(stack_builder.setup_base_learner());
all.example_parser->lbl_parser = CB::cb_label;

bool with_metrics = options.was_supplied("extra_metrics");

if (use_two_pass_svd_impl)
{
auto impl_type = implementation_type::two_pass_svd;
return make_las_with_impl<two_pass_svd_impl, one_rank_spanner_state>(stack_builder, base, impl_type, all,
with_metrics, d, gamma_scale, gamma_exponent, c, apply_shrink_factor, thread_pool_size, block_size,
all.global_metrics.are_metrics_enabled(), d, gamma_scale, gamma_exponent, c, apply_shrink_factor,
thread_pool_size, block_size,
/*use_explicit_simd=*/false);
}
else
{
auto impl_type = implementation_type::one_pass_svd;
return make_las_with_impl<one_pass_svd_impl, one_rank_spanner_state>(stack_builder, base, impl_type, all,
with_metrics, d, gamma_scale, gamma_exponent, c, apply_shrink_factor, thread_pool_size, block_size,
use_simd_in_one_pass_svd_impl);
all.global_metrics.are_metrics_enabled(), d, gamma_scale, gamma_exponent, c, apply_shrink_factor,
thread_pool_size, block_size, use_simd_in_one_pass_svd_impl);
}
}
4 changes: 1 addition & 3 deletions vowpalwabbit/core/src/reductions/cb/cb_explore_adf_regcb.cc
Original file line number Diff line number Diff line change
Expand Up @@ -303,11 +303,9 @@ VW::LEARNER::base_learner* VW::reductions::cb_explore_adf_regcb_setup(VW::setup_
multi_learner* base = as_multiline(stack_builder.setup_base_learner());
all.example_parser->lbl_parser = CB::cb_label;

bool with_metrics = options.was_supplied("extra_metrics");

using explore_type = cb_explore_adf_base<cb_explore_adf_regcb>;
auto data = VW::make_unique<explore_type>(
with_metrics, regcbopt, c0, first_only, min_cb_cost, max_cb_cost, all.model_file_ver);
all.global_metrics.are_metrics_enabled(), regcbopt, c0, first_only, min_cb_cost, max_cb_cost, all.model_file_ver);
auto* l = make_reduction_learner(std::move(data), base, explore_type::learn, explore_type::predict,
stack_builder.get_setupfn_name(cb_explore_adf_regcb_setup))
.set_input_label_type(VW::label_type_t::CB)
Expand Down
6 changes: 2 additions & 4 deletions vowpalwabbit/core/src/reductions/cb/cb_explore_adf_rnd.cc
Original file line number Diff line number Diff line change
Expand Up @@ -304,11 +304,9 @@ VW::LEARNER::base_learner* VW::reductions::cb_explore_adf_rnd_setup(VW::setup_ba
multi_learner* base = as_multiline(stack_builder.setup_base_learner());
all.example_parser->lbl_parser = CB::cb_label;

bool with_metrics = options.was_supplied("extra_metrics");

using explore_type = cb_explore_adf_base<cb_explore_adf_rnd>;
auto data = VW::make_unique<explore_type>(
with_metrics, epsilon, alpha, invlambda, numrnd, base->increment * problem_multiplier, &all);
auto data = VW::make_unique<explore_type>(all.global_metrics.are_metrics_enabled(), epsilon, alpha, invlambda, numrnd,
base->increment * problem_multiplier, &all);

if (epsilon < 0.0 || epsilon > 1.0) { THROW("The value of epsilon must be in [0,1]"); }
auto* l = make_reduction_learner(std::move(data), base, explore_type::learn, explore_type::predict,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,10 +93,8 @@ VW::LEARNER::base_learner* VW::reductions::cb_explore_adf_softmax_setup(VW::setu
VW::LEARNER::multi_learner* base = as_multiline(stack_builder.setup_base_learner());
all.example_parser->lbl_parser = CB::cb_label;

bool with_metrics = options.was_supplied("extra_metrics");

using explore_type = cb_explore_adf_base<cb_explore_adf_softmax>;
auto data = VW::make_unique<explore_type>(with_metrics, epsilon, lambda);
auto data = VW::make_unique<explore_type>(all.global_metrics.are_metrics_enabled(), epsilon, lambda);

if (epsilon < 0.0 || epsilon > 1.0) { THROW("The value of epsilon must be in [0,1]"); }
auto* l = make_reduction_learner(std::move(data), base, explore_type::learn, explore_type::predict,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -381,13 +381,11 @@ VW::LEARNER::base_learner* VW::reductions::cb_explore_adf_squarecb_setup(VW::set
multi_learner* base = as_multiline(stack_builder.setup_base_learner());
all.example_parser->lbl_parser = CB::cb_label;

bool with_metrics = options.was_supplied("extra_metrics");

if (epsilon < 0.0 || epsilon > 1.0) { THROW("The value of epsilon must be in [0,1]"); }

using explore_type = cb_explore_adf_base<cb_explore_adf_squarecb>;
auto data = VW::make_unique<explore_type>(
with_metrics, gamma_scale, gamma_exponent, elim, c0, min_cb_cost, max_cb_cost, all.model_file_ver, epsilon);
auto data = VW::make_unique<explore_type>(all.global_metrics.are_metrics_enabled(), gamma_scale, gamma_exponent, elim,
c0, min_cb_cost, max_cb_cost, all.model_file_ver, epsilon);
auto* l = make_reduction_learner(std::move(data), base, explore_type::learn, explore_type::predict,
stack_builder.get_setupfn_name(cb_explore_adf_squarecb_setup))
.set_input_label_type(VW::label_type_t::CB)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -199,10 +199,8 @@ VW::LEARNER::base_learner* VW::reductions::cb_explore_adf_synthcover_setup(VW::s
VW::LEARNER::multi_learner* base = as_multiline(stack_builder.setup_base_learner());
all.example_parser->lbl_parser = CB::cb_label;

bool with_metrics = options.was_supplied("extra_metrics");

using explore_type = cb_explore_adf_base<cb_explore_adf_synthcover>;
auto data = VW::make_unique<explore_type>(with_metrics, epsilon, psi,
auto data = VW::make_unique<explore_type>(all.global_metrics.are_metrics_enabled(), epsilon, psi,
VW::cast_to_smaller_type<size_t>(synthcoversize), all.get_random_state(), all.model_file_ver);
auto* l = make_reduction_learner(std::move(data), base, explore_type::learn, explore_type::predict,
stack_builder.get_setupfn_name(cb_explore_adf_synthcover_setup))
Expand Down
38 changes: 20 additions & 18 deletions vowpalwabbit/core/src/reductions/metrics.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ void insert_dsjson_metrics(const VW::details::dsjson_metrics* ds_metrics, VW::me
class metrics_data
{
public:
std::string out_file;
size_t learn_count = 0;
size_t predict_count = 0;
};
Expand Down Expand Up @@ -113,6 +112,7 @@ void list_to_json_file(const std::string& filename, const VW::metric_sink& metri
}
else { logger.err_warn("skipping metrics. could not open file for metrics: {}", filename); }
}

void persist(metrics_data& data, VW::metric_sink& metrics)
{
metrics.set_uint("total_predict_calls", data.predict_count);
Expand All @@ -135,40 +135,42 @@ void predict_or_learn(metrics_data& data, T& base, E& ec)
}
} // namespace

void VW::reductions::additional_metrics(VW::workspace& all, VW::metric_sink& sink)
{
sink.set_uint("total_log_calls", all.logger.get_log_count());

std::vector<std::string> enabled_reductions;
if (all.l != nullptr) { all.l->get_enabled_reductions(enabled_reductions); }
insert_dsjson_metrics(all.example_parser->metrics.get(), sink, enabled_reductions);
}

void VW::reductions::output_metrics(VW::workspace& all)
{
if (all.options->was_supplied("extra_metrics"))
metrics_collector& manager = all.global_metrics;
if (manager.are_metrics_enabled())
{
std::string filename = all.options->get_typed_option<std::string>("extra_metrics").value();
VW::metric_sink list_metrics;

all.l->persist_metrics(list_metrics);

for (auto& metric_hook : all.metric_output_hooks) { metric_hook(list_metrics); }

list_metrics.set_uint("total_log_calls", all.logger.get_log_count());

std::vector<std::string> enabled_reductions;
if (all.l != nullptr) { all.l->get_enabled_reductions(enabled_reductions); }
insert_dsjson_metrics(all.example_parser->metrics.get(), list_metrics, enabled_reductions);

list_to_json_file(filename, list_metrics, all.logger);
list_to_json_file(filename, manager.collect_metrics(all.l), all.logger);
}
}

VW::LEARNER::base_learner* VW::reductions::metrics_setup(VW::setup_base_i& stack_builder)
{
options_i& options = *stack_builder.get_options();
VW::config::options_i& options = *stack_builder.get_options();
VW::workspace& all = *stack_builder.get_all_pointer();

auto data = VW::make_unique<metrics_data>();

std::string out_file;
option_group_definition new_options("[Reduction] Debug Metrics");
new_options.add(make_option("extra_metrics", data->out_file)
new_options.add(make_option("extra_metrics", out_file)
.necessary()
.help("Specify filename to write metrics to. Note: There is no fixed schema"));

if (!options.add_parse_and_check_necessary(new_options)) { return nullptr; }

if (data->out_file.empty()) THROW("extra_metrics argument (output filename) is missing.");
if (out_file.empty()) THROW("extra_metrics argument (output filename) is missing.");
all.global_metrics = VW::metrics_collector(true);

auto* base_learner = stack_builder.setup_base_learner();

Expand Down
Loading

0 comments on commit 9741c69

Please sign in to comment.