Skip to content

Commit

Permalink
refactor: [las] remove estimator dependency from oracle
Browse files Browse the repository at this point in the history
  • Loading branch information
lalo committed Aug 15, 2022
1 parent 4c6fdb8 commit 4b9a224
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 35 deletions.
2 changes: 1 addition & 1 deletion test/unit_test/automl_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -474,7 +474,7 @@ BOOST_AUTO_TEST_CASE(one_diff_impl_unittest)
champ_interactions[2].begin(), champ_interactions[2].end(), third.begin(), third.end());

BOOST_CHECK_EQUAL(configs.size(), 1);
oracle.gen_exclusion_configs(estimators, CHAMP);
oracle.gen_exclusion_configs(estimators[CHAMP].first.live_interactions);
BOOST_CHECK_EQUAL(configs.size(), 4);
BOOST_CHECK_EQUAL(prio_queue.size(), 3);

Expand Down
12 changes: 6 additions & 6 deletions vowpalwabbit/core/src/reductions/details/automl/automl_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -265,23 +265,23 @@ void interaction_config_manager<config_oracle_impl>::apply_new_champ(config_orac
const uint64_t winning_challenger_slot, estimator_vec_t& estimators, const uint64_t priority_challengers,
const bool lb_trick)
{
const uint64_t old_champ_slot = 0;
const uint64_t champ_slot = 0;

while (!config_oracle.index_queue.empty()) { config_oracle.index_queue.pop(); };

estimators[winning_challenger_slot].first.eligible_to_inactivate = false;
if (priority_challengers > 1) { estimators[old_champ_slot].first.eligible_to_inactivate = false; }
if (priority_challengers > 1) { estimators[champ_slot].first.eligible_to_inactivate = false; }

auto winner_config_index = estimators[winning_challenger_slot].first.config_index;
std::swap(config_oracle.configs[0], config_oracle.configs[winner_config_index]);
if (winner_config_index != 1) { std::swap(config_oracle.configs[1], config_oracle.configs[winner_config_index]); }
config_oracle.valid_config_size = 2;

estimators[winning_challenger_slot].first.config_index = 0;
estimators[old_champ_slot].first.config_index = 1;
estimators[champ_slot].first.config_index = 1;

auto champ_estimator = std::move(estimators[winning_challenger_slot]);
auto old_champ_estimator = std::move(estimators[old_champ_slot]);
auto old_champ_estimator = std::move(estimators[champ_slot]);

estimators.clear();

Expand Down Expand Up @@ -309,7 +309,7 @@ void interaction_config_manager<config_oracle_impl>::apply_new_champ(config_orac
estimators[1].second.reset_stats();
}

config_oracle.gen_exclusion_configs(estimators, old_champ_slot);
config_oracle.gen_exclusion_configs(estimators[champ_slot].first.live_interactions);
}

template <typename config_oracle_impl>
Expand Down Expand Up @@ -354,7 +354,7 @@ void automl<CMType>::one_step(multi_learner& base, multi_ex& ec, CB::cb_class& l
cm->_ccb_on, cm->ns_counter, cm->interaction_type, cm->_config_oracle.configs, cm->estimators, live_slot);
}
}
cm->_config_oracle.gen_exclusion_configs(cm->estimators, cm->current_champ);
cm->_config_oracle.gen_exclusion_configs(cm->estimators[cm->current_champ].first.live_interactions);
offset_learn(base, ec, logged, labelled_action);
current_state = automl_state::Experimenting;
break;
Expand Down
37 changes: 16 additions & 21 deletions vowpalwabbit/core/src/reductions/details/automl/automl_oracle.cc
Original file line number Diff line number Diff line change
Expand Up @@ -84,15 +84,14 @@ void config_oracle<oracle_impl>::insert_config(std::set<std::vector<namespace_in
++valid_config_size;
}

void oracle_rand_impl::gen_exclusion_configs(
config_oracle<oracle_rand_impl>* co, estimator_vec_t& estimators, const uint64_t current_champ)
void oracle_rand_impl::gen_exclusion_configs(config_oracle<oracle_rand_impl>* co,
const interaction_vec_t& champ_interactions, std::vector<exclusion_config>& configs)
{
auto& champ_interactions = estimators[current_champ].first.live_interactions;
const uint64_t champ_index = 0;
for (uint64_t i = 0; i < CONFIGS_PER_CHAMP_CHANGE; ++i)
{
uint64_t rand_ind = static_cast<uint64_t>(random_state->get_and_update_random() * champ_interactions.size());
std::set<std::vector<namespace_index>> new_exclusions(
co->configs[estimators[current_champ].first.config_index].exclusions);
std::set<std::vector<namespace_index>> new_exclusions(configs[champ_index].exclusions);
if (co->_interaction_type == "quadratic")
{
namespace_index ns1 = champ_interactions[rand_ind][0];
Expand All @@ -115,15 +114,14 @@ void oracle_rand_impl::gen_exclusion_configs(
co->insert_config(std::move(new_exclusions));
}
}
void one_diff_impl::gen_exclusion_configs(
config_oracle<one_diff_impl>* co, estimator_vec_t& estimators, const uint64_t current_champ)
void one_diff_impl::gen_exclusion_configs(config_oracle<one_diff_impl>* co, const interaction_vec_t& champ_interactions,
std::vector<exclusion_config>& configs)
{
auto& champ_interactions = estimators[current_champ].first.live_interactions;
const uint64_t champ_index = 0;
// Add one exclusion (for each interaction)
for (auto& interaction : champ_interactions)
{
std::set<std::vector<namespace_index>> new_exclusions(
co->configs[estimators[current_champ].first.config_index].exclusions);
std::set<std::vector<namespace_index>> new_exclusions(configs[champ_index].exclusions);
if (co->_interaction_type == "quadratic")
{
namespace_index ns1 = interaction[0];
Expand All @@ -149,22 +147,19 @@ void one_diff_impl::gen_exclusion_configs(
co->insert_config(std::move(new_exclusions));
}
// Remove one exclusion (for each exclusion)
for (auto& ns_pair : co->configs[estimators[current_champ].first.config_index].exclusions)
for (auto& ns_pair : configs[champ_index].exclusions)
{
auto new_exclusions = co->configs[estimators[current_champ].first.config_index].exclusions;
auto new_exclusions = configs[champ_index].exclusions;
new_exclusions.erase(ns_pair);
co->insert_config(std::move(new_exclusions));
}
}
void champdupe_impl::gen_exclusion_configs(
config_oracle<champdupe_impl>* co, estimator_vec_t& estimators, const uint64_t current_champ)
void champdupe_impl::gen_exclusion_configs(config_oracle<champdupe_impl>* co,
const interaction_vec_t& champ_interactions, std::vector<exclusion_config>& configs)
{
const uint64_t champ_index = 0;
for (uint64_t i = 0; co->configs.size() <= 2; ++i)
{
co->insert_config(
std::set<std::vector<namespace_index>>(co->configs[estimators[current_champ].first.config_index].exclusions),
true);
}
{ co->insert_config(std::set<std::vector<namespace_index>>(configs[champ_index].exclusions), true); }
}

// This will generate configs based on the current champ. These configs will be
Expand All @@ -173,9 +168,9 @@ void champdupe_impl::gen_exclusion_configs(
// of configs to generate per champ is hard-coded to 5 at the moment.
// TODO: Add logic to avoid duplicate configs (could be very costly)
template <typename oracle_impl>
void config_oracle<oracle_impl>::gen_exclusion_configs(estimator_vec_t& estimators, const uint64_t current_champ)
void config_oracle<oracle_impl>::gen_exclusion_configs(const interaction_vec_t& champ_interactions)
{
_impl.gen_exclusion_configs(this, estimators, current_champ);
_impl.gen_exclusion_configs(this, champ_interactions, configs);
}

// This function is triggered when all sets of interactions generated by the oracle have been tried and
Expand Down
14 changes: 7 additions & 7 deletions vowpalwabbit/core/src/reductions/details/automl_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ struct config_oracle
config_oracle(uint64_t global_lease, priority_func* calc_priority, std::map<namespace_index, uint64_t>& ns_counter,
const std::string& interaction_type, const std::string& oracle_type, std::shared_ptr<VW::rand_state>& rand_state);

void gen_exclusion_configs(estimator_vec_t& estimators, const uint64_t current_champ);
void gen_exclusion_configs(const interaction_vec_t& champ_interactions);
void insert_config(std::set<std::vector<namespace_index>>&& new_exclusions, bool allow_dups = false);
bool repopulate_index_queue();
void insert_qcolcol();
Expand All @@ -116,18 +116,18 @@ struct oracle_rand_impl
{
std::shared_ptr<VW::rand_state> random_state;
oracle_rand_impl(std::shared_ptr<VW::rand_state> random_state) : random_state(std::move(random_state)) {}
void gen_exclusion_configs(
config_oracle<oracle_rand_impl>* config_oracle, estimator_vec_t& estimators, const uint64_t current_champ);
void gen_exclusion_configs(config_oracle<oracle_rand_impl>* co, const interaction_vec_t& champ_interactions,
std::vector<exclusion_config>& configs);
};
struct one_diff_impl
{
void gen_exclusion_configs(
config_oracle<one_diff_impl>* config_oracle, estimator_vec_t& estimators, const uint64_t current_champ);
void gen_exclusion_configs(config_oracle<one_diff_impl>* co, const interaction_vec_t& champ_interactions,
std::vector<exclusion_config>& configs);
};
struct champdupe_impl
{
void gen_exclusion_configs(
config_oracle<champdupe_impl>* config_oracle, estimator_vec_t& estimators, const uint64_t current_champ);
void gen_exclusion_configs(config_oracle<champdupe_impl>* co, const interaction_vec_t& champ_interactions,
std::vector<exclusion_config>& configs);
};

template <typename config_oracle_impl>
Expand Down

0 comments on commit 4b9a224

Please sign in to comment.