Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: [automl] remove estimator dependency from oracle #4117

Merged
merged 4 commits into from
Aug 15, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions test/unit_test/automl_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -460,7 +460,7 @@ BOOST_AUTO_TEST_CASE(one_diff_impl_unittest)
auto& champ_interactions = estimators[CHAMP].first.live_interactions;

BOOST_CHECK_EQUAL(champ_interactions.size(), 0);
gen_interactions(false, ns_counter, oracle._interaction_type, configs, estimators, 0);
gen_interactions_from_exclusions(false, ns_counter, oracle._interaction_type, configs, estimators, 0);
BOOST_CHECK_EQUAL(champ_interactions.size(), 3);

const std::vector<namespace_index> first = {'A', 'A'};
Expand All @@ -474,7 +474,7 @@ BOOST_AUTO_TEST_CASE(one_diff_impl_unittest)
champ_interactions[2].begin(), champ_interactions[2].end(), third.begin(), third.end());

BOOST_CHECK_EQUAL(configs.size(), 1);
oracle.gen_exclusion_configs(estimators, CHAMP);
oracle.gen_exclusion_configs(estimators[CHAMP].first.live_interactions);
BOOST_CHECK_EQUAL(configs.size(), 4);
BOOST_CHECK_EQUAL(prio_queue.size(), 3);

Expand All @@ -498,7 +498,7 @@ BOOST_AUTO_TEST_CASE(one_diff_impl_unittest)
interaction_config_manager<config_oracle<one_diff_impl>>::apply_config_at_slot(estimators, oracle.configs, i,
interaction_config_manager<config_oracle<one_diff_impl>>::choose(oracle.index_queue),
aml->cm->automl_significance_level, aml->cm->automl_estimator_decay, 1);
gen_interactions(false, ns_counter, oracle._interaction_type, configs, estimators, i);
gen_interactions_from_exclusions(false, ns_counter, oracle._interaction_type, configs, estimators, i);
}
BOOST_CHECK_EQUAL(prio_queue.size(), 0);
BOOST_CHECK_EQUAL(estimators.size(), 4);
Expand Down
19 changes: 10 additions & 9 deletions vowpalwabbit/core/src/reductions/details/automl/automl_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,8 @@ void interaction_config_manager<config_oracle_impl>::schedule()
// copy the weights of the champ to the new slot
weights.move_offsets(current_champ, live_slot, wpp);
// Regenerate interactions each time an exclusion is swapped in
gen_interactions(_ccb_on, ns_counter, interaction_type, _config_oracle.configs, estimators, live_slot);
gen_interactions_from_exclusions(
_ccb_on, ns_counter, interaction_type, _config_oracle.configs, estimators, live_slot);
}
}
}
Expand Down Expand Up @@ -265,23 +266,23 @@ void interaction_config_manager<config_oracle_impl>::apply_new_champ(config_orac
const uint64_t winning_challenger_slot, estimator_vec_t& estimators, const uint64_t priority_challengers,
const bool lb_trick)
{
const uint64_t old_champ_slot = 0;
const uint64_t champ_slot = 0;

while (!config_oracle.index_queue.empty()) { config_oracle.index_queue.pop(); };

estimators[winning_challenger_slot].first.eligible_to_inactivate = false;
if (priority_challengers > 1) { estimators[old_champ_slot].first.eligible_to_inactivate = false; }
if (priority_challengers > 1) { estimators[champ_slot].first.eligible_to_inactivate = false; }

auto winner_config_index = estimators[winning_challenger_slot].first.config_index;
std::swap(config_oracle.configs[0], config_oracle.configs[winner_config_index]);
if (winner_config_index != 1) { std::swap(config_oracle.configs[1], config_oracle.configs[winner_config_index]); }
config_oracle.valid_config_size = 2;

estimators[winning_challenger_slot].first.config_index = 0;
estimators[old_champ_slot].first.config_index = 1;
estimators[champ_slot].first.config_index = 1;

auto champ_estimator = std::move(estimators[winning_challenger_slot]);
auto old_champ_estimator = std::move(estimators[old_champ_slot]);
auto old_champ_estimator = std::move(estimators[champ_slot]);

estimators.clear();

Expand Down Expand Up @@ -309,7 +310,7 @@ void interaction_config_manager<config_oracle_impl>::apply_new_champ(config_orac
estimators[1].second.reset_stats();
}

config_oracle.gen_exclusion_configs(estimators, old_champ_slot);
config_oracle.gen_exclusion_configs(estimators[champ_slot].first.live_interactions);
}

template <typename config_oracle_impl>
Expand Down Expand Up @@ -350,11 +351,11 @@ void automl<CMType>::one_step(multi_learner& base, multi_ex& ec, CB::cb_class& l
{
for (uint64_t live_slot = 0; live_slot < cm->estimators.size(); ++live_slot)
{
gen_interactions(
gen_interactions_from_exclusions(
cm->_ccb_on, cm->ns_counter, cm->interaction_type, cm->_config_oracle.configs, cm->estimators, live_slot);
}
}
cm->_config_oracle.gen_exclusion_configs(cm->estimators, cm->current_champ);
cm->_config_oracle.gen_exclusion_configs(cm->estimators[cm->current_champ].first.live_interactions);
offset_learn(base, ec, logged, labelled_action);
current_state = automl_state::Experimenting;
break;
Expand All @@ -366,7 +367,7 @@ void automl<CMType>::one_step(multi_learner& base, multi_ex& ec, CB::cb_class& l
{
for (uint64_t live_slot = 0; live_slot < cm->estimators.size(); ++live_slot)
{
gen_interactions(
gen_interactions_from_exclusions(
cm->_ccb_on, cm->ns_counter, cm->interaction_type, cm->_config_oracle.configs, cm->estimators, live_slot);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ size_t read_model_field(io_buf& io, VW::reductions::automl::interaction_config_m
bytes += read_model_field(io, cm.per_live_model_state_uint64);
for (uint64_t live_slot = 0; live_slot < cm.estimators.size(); ++live_slot)
{
gen_interactions(
gen_interactions_from_exclusions(
cm._ccb_on, cm.ns_counter, cm.interaction_type, cm._config_oracle.configs, cm.estimators, live_slot);
}
return bytes;
Expand Down
35 changes: 15 additions & 20 deletions vowpalwabbit/core/src/reductions/details/automl/automl_oracle.cc
Original file line number Diff line number Diff line change
Expand Up @@ -84,15 +84,14 @@ void config_oracle<oracle_impl>::insert_config(std::set<std::vector<namespace_in
++valid_config_size;
}

void oracle_rand_impl::gen_exclusion_configs(
config_oracle<oracle_rand_impl>* co, estimator_vec_t& estimators, const uint64_t current_champ)
void oracle_rand_impl::gen_exclusion_configs(config_oracle<oracle_rand_impl>* co,
const interaction_vec_t& champ_interactions, std::vector<exclusion_config>& configs)
{
auto& champ_interactions = estimators[current_champ].first.live_interactions;
const uint64_t champ_index = 0;
for (uint64_t i = 0; i < CONFIGS_PER_CHAMP_CHANGE; ++i)
{
uint64_t rand_ind = static_cast<uint64_t>(random_state->get_and_update_random() * champ_interactions.size());
std::set<std::vector<namespace_index>> new_exclusions(
co->configs[estimators[current_champ].first.config_index].exclusions);
std::set<std::vector<namespace_index>> new_exclusions(configs[champ_index].exclusions);
if (co->_interaction_type == "quadratic")
{
namespace_index ns1 = champ_interactions[rand_ind][0];
Expand All @@ -115,15 +114,14 @@ void oracle_rand_impl::gen_exclusion_configs(
co->insert_config(std::move(new_exclusions));
}
}
void one_diff_impl::gen_exclusion_configs(
config_oracle<one_diff_impl>* co, estimator_vec_t& estimators, const uint64_t current_champ)
void one_diff_impl::gen_exclusion_configs(config_oracle<one_diff_impl>* co, const interaction_vec_t& champ_interactions,
std::vector<exclusion_config>& configs)
{
auto& champ_interactions = estimators[current_champ].first.live_interactions;
const uint64_t champ_index = 0;
// Add one exclusion (for each interaction)
for (auto& interaction : champ_interactions)
{
std::set<std::vector<namespace_index>> new_exclusions(
co->configs[estimators[current_champ].first.config_index].exclusions);
std::set<std::vector<namespace_index>> new_exclusions(configs[champ_index].exclusions);
if (co->_interaction_type == "quadratic")
{
namespace_index ns1 = interaction[0];
Expand All @@ -149,22 +147,19 @@ void one_diff_impl::gen_exclusion_configs(
co->insert_config(std::move(new_exclusions));
}
// Remove one exclusion (for each exclusion)
for (auto& ns_pair : co->configs[estimators[current_champ].first.config_index].exclusions)
for (auto& ns_pair : configs[champ_index].exclusions)
{
auto new_exclusions = co->configs[estimators[current_champ].first.config_index].exclusions;
auto new_exclusions = configs[champ_index].exclusions;
new_exclusions.erase(ns_pair);
co->insert_config(std::move(new_exclusions));
}
}
void champdupe_impl::gen_exclusion_configs(
config_oracle<champdupe_impl>* co, estimator_vec_t& estimators, const uint64_t current_champ)
config_oracle<champdupe_impl>* co, const interaction_vec_t&, std::vector<exclusion_config>& configs)
{
const uint64_t champ_index = 0;
for (uint64_t i = 0; co->configs.size() <= 2; ++i)
{
co->insert_config(
std::set<std::vector<namespace_index>>(co->configs[estimators[current_champ].first.config_index].exclusions),
true);
}
{ co->insert_config(std::set<std::vector<namespace_index>>(configs[champ_index].exclusions), true); }
}

// This will generate configs based on the current champ. These configs will be
Expand All @@ -173,9 +168,9 @@ void champdupe_impl::gen_exclusion_configs(
// of configs to generate per champ is hard-coded to 5 at the moment.
// TODO: Add logic to avoid duplicate configs (could be very costly)
template <typename oracle_impl>
void config_oracle<oracle_impl>::gen_exclusion_configs(estimator_vec_t& estimators, const uint64_t current_champ)
void config_oracle<oracle_impl>::gen_exclusion_configs(const interaction_vec_t& champ_interactions)
{
_impl.gen_exclusion_configs(this, estimators, current_champ);
_impl.gen_exclusion_configs(this, champ_interactions, configs);
}

// This function is triggered when all sets of interactions generated by the oracle have been tried and
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@ namespace automl
// from the corresponding live_slot. This function can be swapped out depending on
// preference of how to generate interactions from a given set of exclusions.
// Transforms exclusions -> interactions expected by VW.
void gen_interactions(bool ccb_on, std::map<namespace_index, uint64_t>& ns_counter, std::string& interaction_type,
std::vector<exclusion_config>& configs, estimator_vec_t& estimators, uint64_t live_slot)
void gen_interactions_from_exclusions(bool ccb_on, std::map<namespace_index, uint64_t>& ns_counter,
std::string& interaction_type, std::vector<exclusion_config>& configs, estimator_vec_t& estimators,
uint64_t live_slot)
{
if (interaction_type == "quadratic")
{
Expand Down
19 changes: 10 additions & 9 deletions vowpalwabbit/core/src/reductions/details/automl_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ struct config_oracle
config_oracle(uint64_t global_lease, priority_func* calc_priority, std::map<namespace_index, uint64_t>& ns_counter,
const std::string& interaction_type, const std::string& oracle_type, std::shared_ptr<VW::rand_state>& rand_state);

void gen_exclusion_configs(estimator_vec_t& estimators, const uint64_t current_champ);
void gen_exclusion_configs(const interaction_vec_t& champ_interactions);
void insert_config(std::set<std::vector<namespace_index>>&& new_exclusions, bool allow_dups = false);
bool repopulate_index_queue();
void insert_qcolcol();
Expand All @@ -116,18 +116,18 @@ struct oracle_rand_impl
{
std::shared_ptr<VW::rand_state> random_state;
oracle_rand_impl(std::shared_ptr<VW::rand_state> random_state) : random_state(std::move(random_state)) {}
void gen_exclusion_configs(
config_oracle<oracle_rand_impl>* config_oracle, estimator_vec_t& estimators, const uint64_t current_champ);
void gen_exclusion_configs(config_oracle<oracle_rand_impl>* co, const interaction_vec_t& champ_interactions,
std::vector<exclusion_config>& configs);
};
struct one_diff_impl
{
void gen_exclusion_configs(
config_oracle<one_diff_impl>* config_oracle, estimator_vec_t& estimators, const uint64_t current_champ);
void gen_exclusion_configs(config_oracle<one_diff_impl>* co, const interaction_vec_t& champ_interactions,
std::vector<exclusion_config>& configs);
};
struct champdupe_impl
{
void gen_exclusion_configs(
config_oracle<champdupe_impl>* config_oracle, estimator_vec_t& estimators, const uint64_t current_champ);
void gen_exclusion_configs(config_oracle<champdupe_impl>* co, const interaction_vec_t& champ_interactions,
std::vector<exclusion_config>& configs);
};

template <typename config_oracle_impl>
Expand Down Expand Up @@ -191,8 +191,9 @@ struct interaction_config_manager : config_manager
};

bool count_namespaces(const multi_ex& ecs, std::map<namespace_index, uint64_t>& ns_counter);
void gen_interactions(bool ccb_on, std::map<namespace_index, uint64_t>& ns_counter, std::string& interaction_type,
std::vector<exclusion_config>& configs, estimator_vec_t& estimators, uint64_t live_slot);
void gen_interactions_from_exclusions(bool ccb_on, std::map<namespace_index, uint64_t>& ns_counter,
std::string& interaction_type, std::vector<exclusion_config>& configs, estimator_vec_t& estimators,
uint64_t live_slot);
void apply_config(example* ec, interaction_vec_t* live_interactions);
bool is_allowed_to_remove(const unsigned char ns);
void clear_non_champ_weights(dense_parameters& weights, uint32_t total, uint32_t& wpp);
Expand Down