Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: Add metrics collector #4407

Merged
merged 26 commits into from
Jan 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions python/pylibvw.cc
Original file line number Diff line number Diff line change
Expand Up @@ -327,10 +327,9 @@ py::dict get_learner_metrics(vw_ptr all)
{
py::dict dictionary;

if (all->options->was_supplied("extra_metrics"))
if (all->global_metrics.are_metrics_enabled())
{
VW::metric_sink metrics;
all->l->persist_metrics(metrics);
auto metrics = all->global_metrics.collect_metrics(all->l);

python_dict_writer writer(dictionary);
metrics.visit(writer);
Expand Down
2 changes: 2 additions & 0 deletions vowpalwabbit/core/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ set(vw_core_headers
include/vw/core/loss_functions.h
include/vw/core/memory.h
include/vw/core/merge.h
include/vw/core/metrics_collector.h
include/vw/core/metric_sink.h
include/vw/core/model_utils.h
include/vw/core/multi_model_utils.h
Expand Down Expand Up @@ -228,6 +229,7 @@ set(vw_core_sources
src/learner.cc
src/loss_functions.cc
src/merge.cc
src/metrics_collector.cc
src/metric_sink.cc
src/multiclass.cc
src/multilabel.cc
Expand Down
5 changes: 2 additions & 3 deletions vowpalwabbit/core/include/vw/core/global_data.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "vw/core/error_reporting.h"
#include "vw/core/input_parser.h"
#include "vw/core/interaction_generation_state.h"
#include "vw/core/metrics_collector.h"
#include "vw/core/multi_ex.h"
#include "vw/core/version.h"
#include "vw/core/vw_fwd.h"
Expand Down Expand Up @@ -151,9 +152,7 @@ class workspace
std::unique_ptr<VW::parsers::flatbuffer::parser> flat_converter;
#endif

// This field is experimental and subject to change.
// Used to implement the external binary parser.
std::vector<std::function<void(VW::metric_sink&)>> metric_output_hooks;
VW::metrics_collector global_metrics;

// Experimental field.
// Generic parser interface to make it possible to use any external parser.
Expand Down
35 changes: 35 additions & 0 deletions vowpalwabbit/core/include/vw/core/metrics_collector.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// Copyright (c) by respective owners including Yahoo!, Microsoft, and
// individual contributors. All rights reserved. Released under a BSD (revised)
// license as described in the file LICENSE.=

#pragma once

#include "metric_sink.h"
#include "vw/core/learner_fwd.h"

#include <cstdint>
#include <functional>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <vector>

namespace VW
{
class metrics_collector
{
public:
metrics_collector(bool enabled = false);

using metrics_callback_fn = std::function<void(VW::metric_sink&)>;

bool are_metrics_enabled() const;
void register_metrics_callback(const metrics_callback_fn& callback);
VW::metric_sink collect_metrics(LEARNER::base_learner* l = nullptr) const;

private:
bool _are_metrics_enabled;
std::vector<metrics_callback_fn> _metrics_callbacks;
};
} // namespace VW
1 change: 1 addition & 0 deletions vowpalwabbit/core/include/vw/core/reductions/metrics.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,6 @@ namespace reductions
{
VW::LEARNER::base_learner* metrics_setup(VW::setup_base_i& stack_builder);
void output_metrics(VW::workspace& all);
void additional_metrics(VW::workspace& all, VW::metric_sink& sink);
} // namespace reductions
} // namespace VW
27 changes: 27 additions & 0 deletions vowpalwabbit/core/src/metrics_collector.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// Copyright (c) by respective owners including Yahoo!, Microsoft, and
// individual contributors. All rights reserved. Released under a BSD (revised)
// license as described in the file LICENSE.

#include "vw/core/metrics_collector.h"

#include "vw/core/learner.h"
namespace VW
{
metrics_collector::metrics_collector(bool enabled) : _are_metrics_enabled(enabled) {}

bool metrics_collector::are_metrics_enabled() const { return _are_metrics_enabled; }

void metrics_collector::register_metrics_callback(const metrics_callback_fn& callback)
{
if (_are_metrics_enabled) { _metrics_callbacks.push_back(callback); }
}

metric_sink metrics_collector::collect_metrics(LEARNER::base_learner* l) const
{
VW::metric_sink sink;
if (!_are_metrics_enabled) { THROW("Metrics must be enabled to call collect_metrics"); }
if (l) { l->persist_metrics(sink); }
for (const auto& callback : _metrics_callbacks) { callback(sink); }
return sink;
}
} // namespace VW
2 changes: 1 addition & 1 deletion vowpalwabbit/core/src/parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ void set_json_reader(VW::workspace& all, bool dsjson = false)

all.example_parser->decision_service_json = dsjson;

if (dsjson && all.options->was_supplied("extra_metrics"))
if (dsjson && all.global_metrics.are_metrics_enabled())
{
all.example_parser->metrics = VW::make_unique<VW::details::dsjson_metrics>();
}
Expand Down
11 changes: 3 additions & 8 deletions vowpalwabbit/core/src/reductions/baseline_challenger_cb.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,8 @@ class baseline_challenger_data
VW::estimators::ChiSquared baseline;
discounted_expectation policy_expectation;
float baseline_epsilon;
bool emit_metrics;

baseline_challenger_data(bool emit_metrics, double alpha, double tau)
: baseline(alpha, tau), policy_expectation(tau), emit_metrics(emit_metrics)
{
}
baseline_challenger_data(double alpha, double tau) : baseline(alpha, tau), policy_expectation(tau) {}

static int get_chosen_action(const VW::action_scores& action_scores) { return action_scores[0].action; }

Expand Down Expand Up @@ -176,7 +172,6 @@ void save_load(baseline_challenger_data& data, VW::io_buf& io, bool read, bool t

void persist_metrics(baseline_challenger_data& data, metric_sink& metrics)
{
if (!data.emit_metrics) { return; }
auto ci = static_cast<float>(data.baseline.lower_bound_and_update());
auto exp = static_cast<float>(data.policy_expectation.current());

Expand All @@ -188,6 +183,7 @@ void persist_metrics(baseline_challenger_data& data, metric_sink& metrics)
VW::LEARNER::base_learner* VW::reductions::baseline_challenger_cb_setup(VW::setup_base_i& stack_builder)
{
options_i& options = *stack_builder.get_options();

float alpha;
float tau;
bool is_enabled = false;
Expand Down Expand Up @@ -215,8 +211,7 @@ VW::LEARNER::base_learner* VW::reductions::baseline_challenger_cb_setup(VW::setu

if (!options.was_supplied("cb_adf")) { THROW("cb_challenger requires cb_explore_adf or cb_adf"); }

bool emit_metrics = options.was_supplied("extra_metrics");
auto data = VW::make_unique<baseline_challenger_data>(emit_metrics, alpha, tau);
auto data = VW::make_unique<baseline_challenger_data>(alpha, tau);

auto* l = make_reduction_learner(std::move(data), as_multiline(stack_builder.setup_base_learner()),
learn_or_predict<true>, learn_or_predict<false>, stack_builder.get_setupfn_name(baseline_challenger_cb_setup))
Expand Down
6 changes: 2 additions & 4 deletions vowpalwabbit/core/src/reductions/cb/cb_explore_adf_bag.cc
Original file line number Diff line number Diff line change
Expand Up @@ -206,11 +206,9 @@ VW::LEARNER::base_learner* VW::reductions::cb_explore_adf_bag_setup(VW::setup_ba
VW::LEARNER::multi_learner* base = as_multiline(stack_builder.setup_base_learner());
all.example_parser->lbl_parser = CB::cb_label;

bool with_metrics = options.was_supplied("extra_metrics");

using explore_type = cb_explore_adf_base<cb_explore_adf_bag>;
auto data = VW::make_unique<explore_type>(
with_metrics, epsilon, VW::cast_to_smaller_type<size_t>(bag_size), greedify, first_only, all.get_random_state());
auto data = VW::make_unique<explore_type>(all.global_metrics.are_metrics_enabled(), epsilon,
VW::cast_to_smaller_type<size_t>(bag_size), greedify, first_only, all.get_random_state());
auto* l = make_reduction_learner(std::move(data), base, explore_type::learn, explore_type::predict,
stack_builder.get_setupfn_name(cb_explore_adf_bag_setup))
.set_input_label_type(VW::label_type_t::CB)
Expand Down
8 changes: 3 additions & 5 deletions vowpalwabbit/core/src/reductions/cb/cb_explore_adf_cover.cc
Original file line number Diff line number Diff line change
Expand Up @@ -323,14 +323,12 @@ VW::LEARNER::base_learner* VW::reductions::cb_explore_adf_cover_setup(VW::setup_
epsilon_decay = true;
}

bool with_metrics = options.was_supplied("extra_metrics");

auto* scorer = VW::LEARNER::as_singleline(base->get_learner_by_name_prefix("scorer"));

using explore_type = cb_explore_adf_base<cb_explore_adf_cover>;
auto data =
VW::make_unique<explore_type>(with_metrics, VW::cast_to_smaller_type<size_t>(cover_size), psi, nounif, epsilon,
epsilon_decay, first_only, as_multiline(all.cost_sensitive), scorer, cb_type, all.model_file_ver, all.logger);
auto data = VW::make_unique<explore_type>(all.global_metrics.are_metrics_enabled(),
VW::cast_to_smaller_type<size_t>(cover_size), psi, nounif, epsilon, epsilon_decay, first_only,
as_multiline(all.cost_sensitive), scorer, cb_type, all.model_file_ver, all.logger);
auto* l = make_reduction_learner(std::move(data), base, explore_type::learn, explore_type::predict,
stack_builder.get_setupfn_name(cb_explore_adf_cover_setup))
.set_input_label_type(VW::label_type_t::CB)
Expand Down
6 changes: 2 additions & 4 deletions vowpalwabbit/core/src/reductions/cb/cb_explore_adf_first.cc
Original file line number Diff line number Diff line change
Expand Up @@ -122,11 +122,9 @@ VW::LEARNER::base_learner* VW::reductions::cb_explore_adf_first_setup(VW::setup_
multi_learner* base = as_multiline(stack_builder.setup_base_learner());
all.example_parser->lbl_parser = CB::cb_label;

bool with_metrics = options.was_supplied("extra_metrics");

using explore_type = cb_explore_adf_base<cb_explore_adf_first>;
auto data =
VW::make_unique<explore_type>(with_metrics, VW::cast_to_smaller_type<size_t>(tau), epsilon, all.model_file_ver);
auto data = VW::make_unique<explore_type>(
all.global_metrics.are_metrics_enabled(), VW::cast_to_smaller_type<size_t>(tau), epsilon, all.model_file_ver);

if (epsilon < 0.0 || epsilon > 1.0) { THROW("The value of epsilon must be in [0,1]"); }
auto* l = make_reduction_learner(std::move(data), base, explore_type::learn, explore_type::predict,
Expand Down
4 changes: 1 addition & 3 deletions vowpalwabbit/core/src/reductions/cb/cb_explore_adf_greedy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -133,10 +133,8 @@ VW::LEARNER::base_learner* VW::reductions::cb_explore_adf_greedy_setup(VW::setup
VW::LEARNER::multi_learner* base = as_multiline(stack_builder.setup_base_learner());
all.example_parser->lbl_parser = CB::cb_label;

bool with_metrics = options.was_supplied("extra_metrics");

using explore_type = cb_explore_adf_base<cb_explore_adf_greedy>;
auto data = VW::make_unique<explore_type>(with_metrics, epsilon, first_only);
auto data = VW::make_unique<explore_type>(all.global_metrics.are_metrics_enabled(), epsilon, first_only);

if (epsilon < 0.0 || epsilon > 1.0) { THROW("The value of epsilon must be in [0,1]"); }
auto* l = make_reduction_learner(std::move(data), base, explore_type::learn, explore_type::predict,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -406,20 +406,19 @@ VW::LEARNER::base_learner* VW::reductions::cb_explore_adf_large_action_space_set
VW::LEARNER::multi_learner* base = as_multiline(stack_builder.setup_base_learner());
all.example_parser->lbl_parser = CB::cb_label;

bool with_metrics = options.was_supplied("extra_metrics");

if (use_two_pass_svd_impl)
{
auto impl_type = implementation_type::two_pass_svd;
return make_las_with_impl<two_pass_svd_impl, one_rank_spanner_state>(stack_builder, base, impl_type, all,
with_metrics, d, gamma_scale, gamma_exponent, c, apply_shrink_factor, thread_pool_size, block_size,
all.global_metrics.are_metrics_enabled(), d, gamma_scale, gamma_exponent, c, apply_shrink_factor,
thread_pool_size, block_size,
/*use_explicit_simd=*/false);
}
else
{
auto impl_type = implementation_type::one_pass_svd;
return make_las_with_impl<one_pass_svd_impl, one_rank_spanner_state>(stack_builder, base, impl_type, all,
with_metrics, d, gamma_scale, gamma_exponent, c, apply_shrink_factor, thread_pool_size, block_size,
use_simd_in_one_pass_svd_impl);
all.global_metrics.are_metrics_enabled(), d, gamma_scale, gamma_exponent, c, apply_shrink_factor,
thread_pool_size, block_size, use_simd_in_one_pass_svd_impl);
}
}
4 changes: 1 addition & 3 deletions vowpalwabbit/core/src/reductions/cb/cb_explore_adf_regcb.cc
Original file line number Diff line number Diff line change
Expand Up @@ -303,11 +303,9 @@ VW::LEARNER::base_learner* VW::reductions::cb_explore_adf_regcb_setup(VW::setup_
multi_learner* base = as_multiline(stack_builder.setup_base_learner());
all.example_parser->lbl_parser = CB::cb_label;

bool with_metrics = options.was_supplied("extra_metrics");

using explore_type = cb_explore_adf_base<cb_explore_adf_regcb>;
auto data = VW::make_unique<explore_type>(
with_metrics, regcbopt, c0, first_only, min_cb_cost, max_cb_cost, all.model_file_ver);
all.global_metrics.are_metrics_enabled(), regcbopt, c0, first_only, min_cb_cost, max_cb_cost, all.model_file_ver);
auto* l = make_reduction_learner(std::move(data), base, explore_type::learn, explore_type::predict,
stack_builder.get_setupfn_name(cb_explore_adf_regcb_setup))
.set_input_label_type(VW::label_type_t::CB)
Expand Down
6 changes: 2 additions & 4 deletions vowpalwabbit/core/src/reductions/cb/cb_explore_adf_rnd.cc
Original file line number Diff line number Diff line change
Expand Up @@ -304,11 +304,9 @@ VW::LEARNER::base_learner* VW::reductions::cb_explore_adf_rnd_setup(VW::setup_ba
multi_learner* base = as_multiline(stack_builder.setup_base_learner());
all.example_parser->lbl_parser = CB::cb_label;

bool with_metrics = options.was_supplied("extra_metrics");

using explore_type = cb_explore_adf_base<cb_explore_adf_rnd>;
auto data = VW::make_unique<explore_type>(
with_metrics, epsilon, alpha, invlambda, numrnd, base->increment * problem_multiplier, &all);
auto data = VW::make_unique<explore_type>(all.global_metrics.are_metrics_enabled(), epsilon, alpha, invlambda, numrnd,
base->increment * problem_multiplier, &all);

if (epsilon < 0.0 || epsilon > 1.0) { THROW("The value of epsilon must be in [0,1]"); }
auto* l = make_reduction_learner(std::move(data), base, explore_type::learn, explore_type::predict,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,10 +93,8 @@ VW::LEARNER::base_learner* VW::reductions::cb_explore_adf_softmax_setup(VW::setu
VW::LEARNER::multi_learner* base = as_multiline(stack_builder.setup_base_learner());
all.example_parser->lbl_parser = CB::cb_label;

bool with_metrics = options.was_supplied("extra_metrics");

using explore_type = cb_explore_adf_base<cb_explore_adf_softmax>;
auto data = VW::make_unique<explore_type>(with_metrics, epsilon, lambda);
auto data = VW::make_unique<explore_type>(all.global_metrics.are_metrics_enabled(), epsilon, lambda);

if (epsilon < 0.0 || epsilon > 1.0) { THROW("The value of epsilon must be in [0,1]"); }
auto* l = make_reduction_learner(std::move(data), base, explore_type::learn, explore_type::predict,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -381,13 +381,11 @@ VW::LEARNER::base_learner* VW::reductions::cb_explore_adf_squarecb_setup(VW::set
multi_learner* base = as_multiline(stack_builder.setup_base_learner());
all.example_parser->lbl_parser = CB::cb_label;

bool with_metrics = options.was_supplied("extra_metrics");

if (epsilon < 0.0 || epsilon > 1.0) { THROW("The value of epsilon must be in [0,1]"); }

using explore_type = cb_explore_adf_base<cb_explore_adf_squarecb>;
auto data = VW::make_unique<explore_type>(
with_metrics, gamma_scale, gamma_exponent, elim, c0, min_cb_cost, max_cb_cost, all.model_file_ver, epsilon);
auto data = VW::make_unique<explore_type>(all.global_metrics.are_metrics_enabled(), gamma_scale, gamma_exponent, elim,
c0, min_cb_cost, max_cb_cost, all.model_file_ver, epsilon);
auto* l = make_reduction_learner(std::move(data), base, explore_type::learn, explore_type::predict,
stack_builder.get_setupfn_name(cb_explore_adf_squarecb_setup))
.set_input_label_type(VW::label_type_t::CB)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -199,10 +199,8 @@ VW::LEARNER::base_learner* VW::reductions::cb_explore_adf_synthcover_setup(VW::s
VW::LEARNER::multi_learner* base = as_multiline(stack_builder.setup_base_learner());
all.example_parser->lbl_parser = CB::cb_label;

bool with_metrics = options.was_supplied("extra_metrics");

using explore_type = cb_explore_adf_base<cb_explore_adf_synthcover>;
auto data = VW::make_unique<explore_type>(with_metrics, epsilon, psi,
auto data = VW::make_unique<explore_type>(all.global_metrics.are_metrics_enabled(), epsilon, psi,
VW::cast_to_smaller_type<size_t>(synthcoversize), all.get_random_state(), all.model_file_ver);
auto* l = make_reduction_learner(std::move(data), base, explore_type::learn, explore_type::predict,
stack_builder.get_setupfn_name(cb_explore_adf_synthcover_setup))
Expand Down
38 changes: 20 additions & 18 deletions vowpalwabbit/core/src/reductions/metrics.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ void insert_dsjson_metrics(const VW::details::dsjson_metrics* ds_metrics, VW::me
class metrics_data
{
public:
std::string out_file;
size_t learn_count = 0;
size_t predict_count = 0;
};
Expand Down Expand Up @@ -113,6 +112,7 @@ void list_to_json_file(const std::string& filename, const VW::metric_sink& metri
}
else { logger.err_warn("skipping metrics. could not open file for metrics: {}", filename); }
}

void persist(metrics_data& data, VW::metric_sink& metrics)
{
metrics.set_uint("total_predict_calls", data.predict_count);
Expand All @@ -135,40 +135,42 @@ void predict_or_learn(metrics_data& data, T& base, E& ec)
}
} // namespace

void VW::reductions::additional_metrics(VW::workspace& all, VW::metric_sink& sink)
{
sink.set_uint("total_log_calls", all.logger.get_log_count());

std::vector<std::string> enabled_reductions;
if (all.l != nullptr) { all.l->get_enabled_reductions(enabled_reductions); }
insert_dsjson_metrics(all.example_parser->metrics.get(), sink, enabled_reductions);
}

void VW::reductions::output_metrics(VW::workspace& all)
{
if (all.options->was_supplied("extra_metrics"))
metrics_collector& manager = all.global_metrics;
if (manager.are_metrics_enabled())
{
std::string filename = all.options->get_typed_option<std::string>("extra_metrics").value();
VW::metric_sink list_metrics;

all.l->persist_metrics(list_metrics);

for (auto& metric_hook : all.metric_output_hooks) { metric_hook(list_metrics); }

list_metrics.set_uint("total_log_calls", all.logger.get_log_count());

std::vector<std::string> enabled_reductions;
if (all.l != nullptr) { all.l->get_enabled_reductions(enabled_reductions); }
insert_dsjson_metrics(all.example_parser->metrics.get(), list_metrics, enabled_reductions);

list_to_json_file(filename, list_metrics, all.logger);
list_to_json_file(filename, manager.collect_metrics(all.l), all.logger);
}
}

VW::LEARNER::base_learner* VW::reductions::metrics_setup(VW::setup_base_i& stack_builder)
{
options_i& options = *stack_builder.get_options();
VW::config::options_i& options = *stack_builder.get_options();
VW::workspace& all = *stack_builder.get_all_pointer();

auto data = VW::make_unique<metrics_data>();

std::string out_file;
option_group_definition new_options("[Reduction] Debug Metrics");
new_options.add(make_option("extra_metrics", data->out_file)
new_options.add(make_option("extra_metrics", out_file)
.necessary()
.help("Specify filename to write metrics to. Note: There is no fixed schema"));

if (!options.add_parse_and_check_necessary(new_options)) { return nullptr; }

if (data->out_file.empty()) THROW("extra_metrics argument (output filename) is missing.");
if (out_file.empty()) THROW("extra_metrics argument (output filename) is missing.");
all.global_metrics = VW::metrics_collector(true);

auto* base_learner = stack_builder.setup_base_learner();

Expand Down
Loading