From 097db2026a43d8dd9898465750d8336a1f12897b Mon Sep 17 00:00:00 2001 From: Eduardo Salinas Date: Mon, 15 Aug 2022 19:49:53 +0000 Subject: [PATCH] remove estimator dependency from gen interactions --- test/unit_test/automl_test.cc | 8 ++++++-- .../src/reductions/details/automl/automl_impl.cc | 15 +++++++++------ .../reductions/details/automl/automl_iomodel.cc | 6 ++++-- .../src/reductions/details/automl/automl_util.cc | 11 +++-------- .../core/src/reductions/details/automl_impl.h | 6 +++--- 5 files changed, 25 insertions(+), 21 deletions(-) diff --git a/test/unit_test/automl_test.cc b/test/unit_test/automl_test.cc index 210ff3b211f..fb6c73ee6f9 100644 --- a/test/unit_test/automl_test.cc +++ b/test/unit_test/automl_test.cc @@ -460,7 +460,9 @@ BOOST_AUTO_TEST_CASE(one_diff_impl_unittest) auto& champ_interactions = estimators[CHAMP].first.live_interactions; BOOST_CHECK_EQUAL(champ_interactions.size(), 0); - gen_interactions_from_exclusions(false, ns_counter, oracle._interaction_type, configs, estimators, 0); + auto& exclusions = oracle.configs[estimators[0].first.config_index].exclusions; + auto& interactions = estimators[0].first.live_interactions; + gen_interactions_from_exclusions(false, ns_counter, oracle._interaction_type, exclusions, interactions); BOOST_CHECK_EQUAL(champ_interactions.size(), 3); const std::vector first = {'A', 'A'}; @@ -498,7 +500,9 @@ BOOST_AUTO_TEST_CASE(one_diff_impl_unittest) interaction_config_manager>::apply_config_at_slot(estimators, oracle.configs, i, interaction_config_manager>::choose(oracle.index_queue), aml->cm->automl_significance_level, aml->cm->automl_estimator_decay, 1); - gen_interactions_from_exclusions(false, ns_counter, oracle._interaction_type, configs, estimators, i); + auto& temp_exclusions = oracle.configs[estimators[i].first.config_index].exclusions; + auto& temp_interactions = estimators[i].first.live_interactions; + gen_interactions_from_exclusions(false, ns_counter, oracle._interaction_type, temp_exclusions, temp_interactions); } BOOST_CHECK_EQUAL(prio_queue.size(), 0); BOOST_CHECK_EQUAL(estimators.size(), 4); diff --git a/vowpalwabbit/core/src/reductions/details/automl/automl_impl.cc b/vowpalwabbit/core/src/reductions/details/automl/automl_impl.cc index a02d623d297..2d74d67e068 100644 --- a/vowpalwabbit/core/src/reductions/details/automl/automl_impl.cc +++ b/vowpalwabbit/core/src/reductions/details/automl/automl_impl.cc @@ -170,8 +170,9 @@ void interaction_config_manager::schedule() // copy the weights of the champ to the new slot weights.move_offsets(current_champ, live_slot, wpp); // Regenerate interactions each time an exclusion is swapped in - gen_interactions_from_exclusions( - _ccb_on, ns_counter, interaction_type, _config_oracle.configs, estimators, live_slot); + gen_interactions_from_exclusions(_ccb_on, ns_counter, interaction_type, + _config_oracle.configs[estimators[live_slot].first.config_index].exclusions, + estimators[live_slot].first.live_interactions); } } } @@ -351,8 +352,9 @@ void automl::one_step(multi_learner& base, multi_ex& ec, CB::cb_class& l { for (uint64_t live_slot = 0; live_slot < cm->estimators.size(); ++live_slot) { - gen_interactions_from_exclusions( - cm->_ccb_on, cm->ns_counter, cm->interaction_type, cm->_config_oracle.configs, cm->estimators, live_slot); + auto& exclusions = cm->_config_oracle.configs[cm->estimators[live_slot].first.config_index].exclusions; + auto& interactions = cm->estimators[live_slot].first.live_interactions; + gen_interactions_from_exclusions(cm->_ccb_on, cm->ns_counter, cm->interaction_type, exclusions, interactions); } } cm->_config_oracle.gen_exclusion_configs(cm->estimators[cm->current_champ].first.live_interactions); @@ -367,8 +369,9 @@ void automl::one_step(multi_learner& base, multi_ex& ec, CB::cb_class& l { for (uint64_t live_slot = 0; live_slot < cm->estimators.size(); ++live_slot) { - gen_interactions_from_exclusions( - cm->_ccb_on, cm->ns_counter, cm->interaction_type, cm->_config_oracle.configs, cm->estimators, live_slot); + auto& exclusions = cm->_config_oracle.configs[cm->estimators[live_slot].first.config_index].exclusions; + auto& interactions = cm->estimators[live_slot].first.live_interactions; + gen_interactions_from_exclusions(cm->_ccb_on, cm->ns_counter, cm->interaction_type, exclusions, interactions); } } cm->schedule(); diff --git a/vowpalwabbit/core/src/reductions/details/automl/automl_iomodel.cc b/vowpalwabbit/core/src/reductions/details/automl/automl_iomodel.cc index 9032df7ce5b..b575bba59f2 100644 --- a/vowpalwabbit/core/src/reductions/details/automl/automl_iomodel.cc +++ b/vowpalwabbit/core/src/reductions/details/automl/automl_iomodel.cc @@ -109,8 +109,10 @@ size_t read_model_field(io_buf& io, VW::reductions::automl::interaction_config_m bytes += read_model_field(io, cm.per_live_model_state_uint64); for (uint64_t live_slot = 0; live_slot < cm.estimators.size(); ++live_slot) { - gen_interactions_from_exclusions( - cm._ccb_on, cm.ns_counter, cm.interaction_type, cm._config_oracle.configs, cm.estimators, live_slot); + auto& exclusions = cm._config_oracle.configs[cm.estimators[live_slot].first.config_index].exclusions; + auto& interactions = cm.estimators[live_slot].first.live_interactions; + reductions::automl::gen_interactions_from_exclusions( + cm._ccb_on, cm.ns_counter, cm.interaction_type, exclusions, interactions); } return bytes; } diff --git a/vowpalwabbit/core/src/reductions/details/automl/automl_util.cc b/vowpalwabbit/core/src/reductions/details/automl/automl_util.cc index 7badf06962b..c20cbb94f0d 100644 --- a/vowpalwabbit/core/src/reductions/details/automl/automl_util.cc +++ b/vowpalwabbit/core/src/reductions/details/automl/automl_util.cc @@ -18,14 +18,12 @@ namespace automl // from the corresponding live_slot. This function can be swapped out depending on // preference of how to generate interactions from a given set of exclusions. // Transforms exclusions -> interactions expected by VW. -void gen_interactions_from_exclusions(bool ccb_on, std::map& ns_counter, - std::string& interaction_type, std::vector& configs, estimator_vec_t& estimators, - uint64_t live_slot) +void gen_interactions_from_exclusions(const bool ccb_on, const std::map& ns_counter, + const std::string& interaction_type, const std::set>& exclusions, + interaction_vec_t& interactions) { if (interaction_type == "quadratic") { - auto& exclusions = configs[estimators[live_slot].first.config_index].exclusions; - auto& interactions = estimators[live_slot].first.live_interactions; if (!interactions.empty()) { interactions.clear(); } for (auto it = ns_counter.begin(); it != ns_counter.end(); ++it) { @@ -40,8 +38,6 @@ void gen_interactions_from_exclusions(bool ccb_on, std::map> empty; - auto& interactions = estimators[live_slot].first.live_interactions; ccb::insert_ccb_interactions(interactions, empty); } } diff --git a/vowpalwabbit/core/src/reductions/details/automl_impl.h b/vowpalwabbit/core/src/reductions/details/automl_impl.h index 4f82f89074e..462c08e41c1 100644 --- a/vowpalwabbit/core/src/reductions/details/automl_impl.h +++ b/vowpalwabbit/core/src/reductions/details/automl_impl.h @@ -191,9 +191,9 @@ struct interaction_config_manager : config_manager }; bool count_namespaces(const multi_ex& ecs, std::map& ns_counter); -void gen_interactions_from_exclusions(bool ccb_on, std::map& ns_counter, - std::string& interaction_type, std::vector& configs, estimator_vec_t& estimators, - uint64_t live_slot); +void gen_interactions_from_exclusions(const bool ccb_on, const std::map& ns_counter, + const std::string& interaction_type, const std::set>& exclusions, + interaction_vec_t& interactions); void apply_config(example* ec, interaction_vec_t* live_interactions); bool is_allowed_to_remove(const unsigned char ns); void clear_non_champ_weights(dense_parameters& weights, uint32_t total, uint32_t& wpp);