Skip to content

Commit

Permalink
remove estimator dependency from gen interactions
Browse files Browse the repository at this point in the history
  • Loading branch information
lalo committed Aug 15, 2022
1 parent bdfca2c commit 097db20
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 21 deletions.
8 changes: 6 additions & 2 deletions test/unit_test/automl_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -460,7 +460,9 @@ BOOST_AUTO_TEST_CASE(one_diff_impl_unittest)
auto& champ_interactions = estimators[CHAMP].first.live_interactions;

BOOST_CHECK_EQUAL(champ_interactions.size(), 0);
gen_interactions_from_exclusions(false, ns_counter, oracle._interaction_type, configs, estimators, 0);
auto& exclusions = oracle.configs[estimators[0].first.config_index].exclusions;
auto& interactions = estimators[0].first.live_interactions;
gen_interactions_from_exclusions(false, ns_counter, oracle._interaction_type, exclusions, interactions);
BOOST_CHECK_EQUAL(champ_interactions.size(), 3);

const std::vector<namespace_index> first = {'A', 'A'};
Expand Down Expand Up @@ -498,7 +500,9 @@ BOOST_AUTO_TEST_CASE(one_diff_impl_unittest)
interaction_config_manager<config_oracle<one_diff_impl>>::apply_config_at_slot(estimators, oracle.configs, i,
interaction_config_manager<config_oracle<one_diff_impl>>::choose(oracle.index_queue),
aml->cm->automl_significance_level, aml->cm->automl_estimator_decay, 1);
gen_interactions_from_exclusions(false, ns_counter, oracle._interaction_type, configs, estimators, i);
auto& temp_exclusions = oracle.configs[estimators[i].first.config_index].exclusions;
auto& temp_interactions = estimators[i].first.live_interactions;
gen_interactions_from_exclusions(false, ns_counter, oracle._interaction_type, temp_exclusions, temp_interactions);
}
BOOST_CHECK_EQUAL(prio_queue.size(), 0);
BOOST_CHECK_EQUAL(estimators.size(), 4);
Expand Down
15 changes: 9 additions & 6 deletions vowpalwabbit/core/src/reductions/details/automl/automl_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,9 @@ void interaction_config_manager<config_oracle_impl>::schedule()
// copy the weights of the champ to the new slot
weights.move_offsets(current_champ, live_slot, wpp);
// Regenerate interactions each time an exclusion is swapped in
gen_interactions_from_exclusions(
_ccb_on, ns_counter, interaction_type, _config_oracle.configs, estimators, live_slot);
gen_interactions_from_exclusions(_ccb_on, ns_counter, interaction_type,
_config_oracle.configs[estimators[live_slot].first.config_index].exclusions,
estimators[live_slot].first.live_interactions);
}
}
}
Expand Down Expand Up @@ -351,8 +352,9 @@ void automl<CMType>::one_step(multi_learner& base, multi_ex& ec, CB::cb_class& l
{
for (uint64_t live_slot = 0; live_slot < cm->estimators.size(); ++live_slot)
{
gen_interactions_from_exclusions(
cm->_ccb_on, cm->ns_counter, cm->interaction_type, cm->_config_oracle.configs, cm->estimators, live_slot);
auto& exclusions = cm->_config_oracle.configs[cm->estimators[live_slot].first.config_index].exclusions;
auto& interactions = cm->estimators[live_slot].first.live_interactions;
gen_interactions_from_exclusions(cm->_ccb_on, cm->ns_counter, cm->interaction_type, exclusions, interactions);
}
}
cm->_config_oracle.gen_exclusion_configs(cm->estimators[cm->current_champ].first.live_interactions);
Expand All @@ -367,8 +369,9 @@ void automl<CMType>::one_step(multi_learner& base, multi_ex& ec, CB::cb_class& l
{
for (uint64_t live_slot = 0; live_slot < cm->estimators.size(); ++live_slot)
{
gen_interactions_from_exclusions(
cm->_ccb_on, cm->ns_counter, cm->interaction_type, cm->_config_oracle.configs, cm->estimators, live_slot);
auto& exclusions = cm->_config_oracle.configs[cm->estimators[live_slot].first.config_index].exclusions;
auto& interactions = cm->estimators[live_slot].first.live_interactions;
gen_interactions_from_exclusions(cm->_ccb_on, cm->ns_counter, cm->interaction_type, exclusions, interactions);
}
}
cm->schedule();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,10 @@ size_t read_model_field(io_buf& io, VW::reductions::automl::interaction_config_m
bytes += read_model_field(io, cm.per_live_model_state_uint64);
for (uint64_t live_slot = 0; live_slot < cm.estimators.size(); ++live_slot)
{
gen_interactions_from_exclusions(
cm._ccb_on, cm.ns_counter, cm.interaction_type, cm._config_oracle.configs, cm.estimators, live_slot);
auto& exclusions = cm._config_oracle.configs[cm.estimators[live_slot].first.config_index].exclusions;
auto& interactions = cm.estimators[live_slot].first.live_interactions;
reductions::automl::gen_interactions_from_exclusions(
cm._ccb_on, cm.ns_counter, cm.interaction_type, exclusions, interactions);
}
return bytes;
}
Expand Down
11 changes: 3 additions & 8 deletions vowpalwabbit/core/src/reductions/details/automl/automl_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,12 @@ namespace automl
// from the corresponding live_slot. This function can be swapped out depending on
// preference of how to generate interactions from a given set of exclusions.
// Transforms exclusions -> interactions expected by VW.
void gen_interactions_from_exclusions(bool ccb_on, std::map<namespace_index, uint64_t>& ns_counter,
std::string& interaction_type, std::vector<exclusion_config>& configs, estimator_vec_t& estimators,
uint64_t live_slot)
void gen_interactions_from_exclusions(const bool ccb_on, const std::map<namespace_index, uint64_t>& ns_counter,
const std::string& interaction_type, const std::set<std::vector<namespace_index>>& exclusions,
interaction_vec_t& interactions)
{
if (interaction_type == "quadratic")
{
auto& exclusions = configs[estimators[live_slot].first.config_index].exclusions;
auto& interactions = estimators[live_slot].first.live_interactions;
if (!interactions.empty()) { interactions.clear(); }
for (auto it = ns_counter.begin(); it != ns_counter.end(); ++it)
{
Expand All @@ -40,8 +38,6 @@ void gen_interactions_from_exclusions(bool ccb_on, std::map<namespace_index, uin
}
else if (interaction_type == "cubic")
{
auto& exclusions = configs[estimators[live_slot].first.config_index].exclusions;
auto& interactions = estimators[live_slot].first.live_interactions;
if (!interactions.empty()) { interactions.clear(); }
for (auto it = ns_counter.begin(); it != ns_counter.end(); ++it)
{
Expand All @@ -66,7 +62,6 @@ void gen_interactions_from_exclusions(bool ccb_on, std::map<namespace_index, uin
if (ccb_on)
{
std::vector<std::vector<extent_term>> empty;
auto& interactions = estimators[live_slot].first.live_interactions;
ccb::insert_ccb_interactions(interactions, empty);
}
}
Expand Down
6 changes: 3 additions & 3 deletions vowpalwabbit/core/src/reductions/details/automl_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -191,9 +191,9 @@ struct interaction_config_manager : config_manager
};

bool count_namespaces(const multi_ex& ecs, std::map<namespace_index, uint64_t>& ns_counter);
void gen_interactions_from_exclusions(bool ccb_on, std::map<namespace_index, uint64_t>& ns_counter,
std::string& interaction_type, std::vector<exclusion_config>& configs, estimator_vec_t& estimators,
uint64_t live_slot);
void gen_interactions_from_exclusions(const bool ccb_on, const std::map<namespace_index, uint64_t>& ns_counter,
const std::string& interaction_type, const std::set<std::vector<namespace_index>>& exclusions,
interaction_vec_t& interactions);
void apply_config(example* ec, interaction_vec_t* live_interactions);
bool is_allowed_to_remove(const unsigned char ns);
void clear_non_champ_weights(dense_parameters& weights, uint32_t total, uint32_t& wpp);
Expand Down

0 comments on commit 097db20

Please sign in to comment.