Skip to content

Commit

Permalink
refactor: [automl] remove estimator dependency from oracle (#4117)
Browse files Browse the repository at this point in the history
- Also remove the estimator dependency from the `gen_interactions` function (now `gen_interactions_from_exclusions`)
  • Loading branch information
lalo authored Aug 15, 2022
1 parent d811ea7 commit ad35c3b
Show file tree
Hide file tree
Showing 6 changed files with 54 additions and 52 deletions.
10 changes: 7 additions & 3 deletions test/unit_test/automl_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -460,7 +460,9 @@ BOOST_AUTO_TEST_CASE(one_diff_impl_unittest)
auto& champ_interactions = estimators[CHAMP].first.live_interactions;

BOOST_CHECK_EQUAL(champ_interactions.size(), 0);
gen_interactions(false, ns_counter, oracle._interaction_type, configs, estimators, 0);
auto& exclusions = oracle.configs[estimators[0].first.config_index].exclusions;
auto& interactions = estimators[0].first.live_interactions;
gen_interactions_from_exclusions(false, ns_counter, oracle._interaction_type, exclusions, interactions);
BOOST_CHECK_EQUAL(champ_interactions.size(), 3);

const std::vector<namespace_index> first = {'A', 'A'};
Expand All @@ -474,7 +476,7 @@ BOOST_AUTO_TEST_CASE(one_diff_impl_unittest)
champ_interactions[2].begin(), champ_interactions[2].end(), third.begin(), third.end());

BOOST_CHECK_EQUAL(configs.size(), 1);
oracle.gen_exclusion_configs(estimators, CHAMP);
oracle.gen_exclusion_configs(estimators[CHAMP].first.live_interactions);
BOOST_CHECK_EQUAL(configs.size(), 4);
BOOST_CHECK_EQUAL(prio_queue.size(), 3);

Expand All @@ -498,7 +500,9 @@ BOOST_AUTO_TEST_CASE(one_diff_impl_unittest)
interaction_config_manager<config_oracle<one_diff_impl>>::apply_config_at_slot(estimators, oracle.configs, i,
interaction_config_manager<config_oracle<one_diff_impl>>::choose(oracle.index_queue),
aml->cm->automl_significance_level, aml->cm->automl_estimator_decay, 1);
gen_interactions(false, ns_counter, oracle._interaction_type, configs, estimators, i);
auto& temp_exclusions = oracle.configs[estimators[i].first.config_index].exclusions;
auto& temp_interactions = estimators[i].first.live_interactions;
gen_interactions_from_exclusions(false, ns_counter, oracle._interaction_type, temp_exclusions, temp_interactions);
}
BOOST_CHECK_EQUAL(prio_queue.size(), 0);
BOOST_CHECK_EQUAL(estimators.size(), 4);
Expand Down
26 changes: 15 additions & 11 deletions vowpalwabbit/core/src/reductions/details/automl/automl_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,9 @@ void interaction_config_manager<config_oracle_impl>::schedule()
// copy the weights of the champ to the new slot
weights.move_offsets(current_champ, live_slot, wpp);
// Regenerate interactions each time an exclusion is swapped in
gen_interactions(_ccb_on, ns_counter, interaction_type, _config_oracle.configs, estimators, live_slot);
gen_interactions_from_exclusions(_ccb_on, ns_counter, interaction_type,
_config_oracle.configs[estimators[live_slot].first.config_index].exclusions,
estimators[live_slot].first.live_interactions);
}
}
}
Expand Down Expand Up @@ -265,23 +267,23 @@ void interaction_config_manager<config_oracle_impl>::apply_new_champ(config_orac
const uint64_t winning_challenger_slot, estimator_vec_t& estimators, const uint64_t priority_challengers,
const bool lb_trick)
{
const uint64_t old_champ_slot = 0;
const uint64_t champ_slot = 0;

while (!config_oracle.index_queue.empty()) { config_oracle.index_queue.pop(); };

estimators[winning_challenger_slot].first.eligible_to_inactivate = false;
if (priority_challengers > 1) { estimators[old_champ_slot].first.eligible_to_inactivate = false; }
if (priority_challengers > 1) { estimators[champ_slot].first.eligible_to_inactivate = false; }

auto winner_config_index = estimators[winning_challenger_slot].first.config_index;
std::swap(config_oracle.configs[0], config_oracle.configs[winner_config_index]);
if (winner_config_index != 1) { std::swap(config_oracle.configs[1], config_oracle.configs[winner_config_index]); }
config_oracle.valid_config_size = 2;

estimators[winning_challenger_slot].first.config_index = 0;
estimators[old_champ_slot].first.config_index = 1;
estimators[champ_slot].first.config_index = 1;

auto champ_estimator = std::move(estimators[winning_challenger_slot]);
auto old_champ_estimator = std::move(estimators[old_champ_slot]);
auto old_champ_estimator = std::move(estimators[champ_slot]);

estimators.clear();

Expand Down Expand Up @@ -309,7 +311,7 @@ void interaction_config_manager<config_oracle_impl>::apply_new_champ(config_orac
estimators[1].second.reset_stats();
}

config_oracle.gen_exclusion_configs(estimators, old_champ_slot);
config_oracle.gen_exclusion_configs(estimators[champ_slot].first.live_interactions);
}

template <typename config_oracle_impl>
Expand Down Expand Up @@ -350,11 +352,12 @@ void automl<CMType>::one_step(multi_learner& base, multi_ex& ec, CB::cb_class& l
{
for (uint64_t live_slot = 0; live_slot < cm->estimators.size(); ++live_slot)
{
gen_interactions(
cm->_ccb_on, cm->ns_counter, cm->interaction_type, cm->_config_oracle.configs, cm->estimators, live_slot);
auto& exclusions = cm->_config_oracle.configs[cm->estimators[live_slot].first.config_index].exclusions;
auto& interactions = cm->estimators[live_slot].first.live_interactions;
gen_interactions_from_exclusions(cm->_ccb_on, cm->ns_counter, cm->interaction_type, exclusions, interactions);
}
}
cm->_config_oracle.gen_exclusion_configs(cm->estimators, cm->current_champ);
cm->_config_oracle.gen_exclusion_configs(cm->estimators[cm->current_champ].first.live_interactions);
offset_learn(base, ec, logged, labelled_action);
current_state = automl_state::Experimenting;
break;
Expand All @@ -366,8 +369,9 @@ void automl<CMType>::one_step(multi_learner& base, multi_ex& ec, CB::cb_class& l
{
for (uint64_t live_slot = 0; live_slot < cm->estimators.size(); ++live_slot)
{
gen_interactions(
cm->_ccb_on, cm->ns_counter, cm->interaction_type, cm->_config_oracle.configs, cm->estimators, live_slot);
auto& exclusions = cm->_config_oracle.configs[cm->estimators[live_slot].first.config_index].exclusions;
auto& interactions = cm->estimators[live_slot].first.live_interactions;
gen_interactions_from_exclusions(cm->_ccb_on, cm->ns_counter, cm->interaction_type, exclusions, interactions);
}
}
cm->schedule();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,10 @@ size_t read_model_field(io_buf& io, VW::reductions::automl::interaction_config_m
bytes += read_model_field(io, cm.per_live_model_state_uint64);
for (uint64_t live_slot = 0; live_slot < cm.estimators.size(); ++live_slot)
{
gen_interactions(
cm._ccb_on, cm.ns_counter, cm.interaction_type, cm._config_oracle.configs, cm.estimators, live_slot);
auto& exclusions = cm._config_oracle.configs[cm.estimators[live_slot].first.config_index].exclusions;
auto& interactions = cm.estimators[live_slot].first.live_interactions;
reductions::automl::gen_interactions_from_exclusions(
cm._ccb_on, cm.ns_counter, cm.interaction_type, exclusions, interactions);
}
return bytes;
}
Expand Down
35 changes: 15 additions & 20 deletions vowpalwabbit/core/src/reductions/details/automl/automl_oracle.cc
Original file line number Diff line number Diff line change
Expand Up @@ -84,15 +84,14 @@ void config_oracle<oracle_impl>::insert_config(std::set<std::vector<namespace_in
++valid_config_size;
}

void oracle_rand_impl::gen_exclusion_configs(
config_oracle<oracle_rand_impl>* co, estimator_vec_t& estimators, const uint64_t current_champ)
void oracle_rand_impl::gen_exclusion_configs(config_oracle<oracle_rand_impl>* co,
const interaction_vec_t& champ_interactions, std::vector<exclusion_config>& configs)
{
auto& champ_interactions = estimators[current_champ].first.live_interactions;
const uint64_t champ_index = 0;
for (uint64_t i = 0; i < CONFIGS_PER_CHAMP_CHANGE; ++i)
{
uint64_t rand_ind = static_cast<uint64_t>(random_state->get_and_update_random() * champ_interactions.size());
std::set<std::vector<namespace_index>> new_exclusions(
co->configs[estimators[current_champ].first.config_index].exclusions);
std::set<std::vector<namespace_index>> new_exclusions(configs[champ_index].exclusions);
if (co->_interaction_type == "quadratic")
{
namespace_index ns1 = champ_interactions[rand_ind][0];
Expand All @@ -115,15 +114,14 @@ void oracle_rand_impl::gen_exclusion_configs(
co->insert_config(std::move(new_exclusions));
}
}
void one_diff_impl::gen_exclusion_configs(
config_oracle<one_diff_impl>* co, estimator_vec_t& estimators, const uint64_t current_champ)
void one_diff_impl::gen_exclusion_configs(config_oracle<one_diff_impl>* co, const interaction_vec_t& champ_interactions,
std::vector<exclusion_config>& configs)
{
auto& champ_interactions = estimators[current_champ].first.live_interactions;
const uint64_t champ_index = 0;
// Add one exclusion (for each interaction)
for (auto& interaction : champ_interactions)
{
std::set<std::vector<namespace_index>> new_exclusions(
co->configs[estimators[current_champ].first.config_index].exclusions);
std::set<std::vector<namespace_index>> new_exclusions(configs[champ_index].exclusions);
if (co->_interaction_type == "quadratic")
{
namespace_index ns1 = interaction[0];
Expand All @@ -149,22 +147,19 @@ void one_diff_impl::gen_exclusion_configs(
co->insert_config(std::move(new_exclusions));
}
// Remove one exclusion (for each exclusion)
for (auto& ns_pair : co->configs[estimators[current_champ].first.config_index].exclusions)
for (auto& ns_pair : configs[champ_index].exclusions)
{
auto new_exclusions = co->configs[estimators[current_champ].first.config_index].exclusions;
auto new_exclusions = configs[champ_index].exclusions;
new_exclusions.erase(ns_pair);
co->insert_config(std::move(new_exclusions));
}
}
// champdupe oracle: while the oracle holds at most two configs, inserts a copy of the champ's
// exclusion set, passing `true` as insert_config's second argument (allow_dups).
// NOTE(review): this span looks like a diff rendering with its +/- markers lost, so two variants
// are interleaved: the parameter list and loop body that go through `estimators` appear to be the
// pre-change version, and the ones that read `configs[champ_index]` the post-change version.
// Confirm against the real file before treating this as compilable code.
void champdupe_impl::gen_exclusion_configs(
// Variant taking the estimator vector and the champ's slot.
config_oracle<champdupe_impl>* co, estimator_vec_t& estimators, const uint64_t current_champ)
// Variant taking the config vector directly; the interaction_vec_t parameter is unnamed and unused.
config_oracle<champdupe_impl>* co, const interaction_vec_t&, std::vector<exclusion_config>& configs)
{
// The champ's exclusions are read from configs[0] in the `configs`-based variant.
const uint64_t champ_index = 0;
// `i` is incremented but never read; the loop condition depends only on co->configs.size().
for (uint64_t i = 0; co->configs.size() <= 2; ++i)
{
// Estimator-based body: looks the champ's config up via estimators[current_champ].first.config_index.
co->insert_config(
std::set<std::vector<namespace_index>>(co->configs[estimators[current_champ].first.config_index].exclusions),
true);
}
// configs-based body: copies configs[champ_index].exclusions instead.
{ co->insert_config(std::set<std::vector<namespace_index>>(configs[champ_index].exclusions), true); }
}

// This will generate configs based on the current champ. These configs will be
Expand All @@ -173,9 +168,9 @@ void champdupe_impl::gen_exclusion_configs(
// of configs to generate per champ is hard-coded to 5 at the moment.
// TODO: Add logic to avoid duplicate configs (could be very costly)
// Forwards config generation to the `_impl` member, passing `this` so the implementation can call
// back into the oracle (the visible implementations call co->insert_config).
// NOTE(review): two signature lines and two forwarding calls appear here because the span seems to
// be a diff rendering; the (estimators, current_champ) form looks like the pre-change version and
// the (champ_interactions) form, which also passes the oracle's `configs`, the post-change one.
template <typename oracle_impl>
void config_oracle<oracle_impl>::gen_exclusion_configs(estimator_vec_t& estimators, const uint64_t current_champ)
void config_oracle<oracle_impl>::gen_exclusion_configs(const interaction_vec_t& champ_interactions)
{
_impl.gen_exclusion_configs(this, estimators, current_champ);
_impl.gen_exclusion_configs(this, champ_interactions, configs);
}

// This function is triggered when all sets of interactions generated by the oracle have been tried and
Expand Down
10 changes: 3 additions & 7 deletions vowpalwabbit/core/src/reductions/details/automl/automl_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,12 @@ namespace automl
// from the corresponding live_slot. This function can be swapped out depending on
// preference of how to generate interactions from a given set of exclusions.
// Transforms exclusions -> interactions expected by VW.
void gen_interactions(bool ccb_on, std::map<namespace_index, uint64_t>& ns_counter, std::string& interaction_type,
std::vector<exclusion_config>& configs, estimator_vec_t& estimators, uint64_t live_slot)
void gen_interactions_from_exclusions(const bool ccb_on, const std::map<namespace_index, uint64_t>& ns_counter,
const std::string& interaction_type, const std::set<std::vector<namespace_index>>& exclusions,
interaction_vec_t& interactions)
{
if (interaction_type == "quadratic")
{
auto& exclusions = configs[estimators[live_slot].first.config_index].exclusions;
auto& interactions = estimators[live_slot].first.live_interactions;
if (!interactions.empty()) { interactions.clear(); }
for (auto it = ns_counter.begin(); it != ns_counter.end(); ++it)
{
Expand All @@ -39,8 +38,6 @@ void gen_interactions(bool ccb_on, std::map<namespace_index, uint64_t>& ns_count
}
else if (interaction_type == "cubic")
{
auto& exclusions = configs[estimators[live_slot].first.config_index].exclusions;
auto& interactions = estimators[live_slot].first.live_interactions;
if (!interactions.empty()) { interactions.clear(); }
for (auto it = ns_counter.begin(); it != ns_counter.end(); ++it)
{
Expand All @@ -65,7 +62,6 @@ void gen_interactions(bool ccb_on, std::map<namespace_index, uint64_t>& ns_count
if (ccb_on)
{
std::vector<std::vector<extent_term>> empty;
auto& interactions = estimators[live_slot].first.live_interactions;
ccb::insert_ccb_interactions(interactions, empty);
}
}
Expand Down
19 changes: 10 additions & 9 deletions vowpalwabbit/core/src/reductions/details/automl_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ struct config_oracle
config_oracle(uint64_t global_lease, priority_func* calc_priority, std::map<namespace_index, uint64_t>& ns_counter,
const std::string& interaction_type, const std::string& oracle_type, std::shared_ptr<VW::rand_state>& rand_state);

void gen_exclusion_configs(estimator_vec_t& estimators, const uint64_t current_champ);
void gen_exclusion_configs(const interaction_vec_t& champ_interactions);
void insert_config(std::set<std::vector<namespace_index>>&& new_exclusions, bool allow_dups = false);
bool repopulate_index_queue();
void insert_qcolcol();
Expand All @@ -116,18 +116,18 @@ struct oracle_rand_impl
{
std::shared_ptr<VW::rand_state> random_state;
oracle_rand_impl(std::shared_ptr<VW::rand_state> random_state) : random_state(std::move(random_state)) {}
void gen_exclusion_configs(
config_oracle<oracle_rand_impl>* config_oracle, estimator_vec_t& estimators, const uint64_t current_champ);
void gen_exclusion_configs(config_oracle<oracle_rand_impl>* co, const interaction_vec_t& champ_interactions,
std::vector<exclusion_config>& configs);
};
// Oracle implementation whose gen_exclusion_configs (defined in automl_oracle.cc) inserts one
// config per champ interaction with that interaction added as an exclusion, and one config per
// existing exclusion with that exclusion removed.
// NOTE(review): both declarations below are shown because this span appears to be a diff
// rendering; the estimator-based one looks like the pre-change signature and the
// interaction_vec_t/configs-based one the post-change signature — confirm against the real header.
struct one_diff_impl
{
// Variant that receives the estimator vector and the champ's slot index.
void gen_exclusion_configs(
config_oracle<one_diff_impl>* config_oracle, estimator_vec_t& estimators, const uint64_t current_champ);
// Variant that receives the champ's interactions and the config vector directly.
void gen_exclusion_configs(config_oracle<one_diff_impl>* co, const interaction_vec_t& champ_interactions,
std::vector<exclusion_config>& configs);
};
// Oracle implementation whose gen_exclusion_configs (defined in automl_oracle.cc) inserts
// duplicates of the champ's exclusion set while the oracle holds at most two configs.
// NOTE(review): both declarations below are shown because this span appears to be a diff
// rendering; the estimator-based one looks like the pre-change signature and the
// interaction_vec_t/configs-based one the post-change signature — confirm against the real header.
struct champdupe_impl
{
// Variant that receives the estimator vector and the champ's slot index.
void gen_exclusion_configs(
config_oracle<champdupe_impl>* config_oracle, estimator_vec_t& estimators, const uint64_t current_champ);
// Variant that receives the champ's interactions (left unnamed and unused by the visible
// definition) and the config vector directly.
void gen_exclusion_configs(config_oracle<champdupe_impl>* co, const interaction_vec_t& champ_interactions,
std::vector<exclusion_config>& configs);
};

template <typename config_oracle_impl>
Expand Down Expand Up @@ -191,8 +191,9 @@ struct interaction_config_manager : config_manager
};

bool count_namespaces(const multi_ex& ecs, std::map<namespace_index, uint64_t>& ns_counter);
// NOTE(review): the next two declarations appear together because this span seems to be a diff
// rendering — gen_interactions looks like the pre-change helper and
// gen_interactions_from_exclusions its replacement; confirm against the real header.
//
// gen_interactions: selects the exclusions through configs[estimators[live_slot].first.config_index]
// and writes the result into estimators[live_slot].first.live_interactions.
void gen_interactions(bool ccb_on, std::map<namespace_index, uint64_t>& ns_counter, std::string& interaction_type,
std::vector<exclusion_config>& configs, estimator_vec_t& estimators, uint64_t live_slot);
// gen_interactions_from_exclusions: takes the exclusion set and the output interaction vector
// directly, with every input passed as const. The visible definition clears `interactions`,
// walks `ns_counter` for the "quadratic" and "cubic" interaction types, and appends CCB
// interactions when `ccb_on` is true.
void gen_interactions_from_exclusions(const bool ccb_on, const std::map<namespace_index, uint64_t>& ns_counter,
const std::string& interaction_type, const std::set<std::vector<namespace_index>>& exclusions,
interaction_vec_t& interactions);
void apply_config(example* ec, interaction_vec_t* live_interactions);
bool is_allowed_to_remove(const unsigned char ns);
void clear_non_champ_weights(dense_parameters& weights, uint32_t total, uint32_t& wpp);
Expand Down

0 comments on commit ad35c3b

Please sign in to comment.