Skip to content

feat: Model merging with delta objects #4177

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Oct 6, 2022
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 104 additions & 21 deletions vowpalwabbit/core/include/vw/core/learner.h
Original file line number Diff line number Diff line change
Expand Up @@ -127,12 +127,15 @@ struct finish_example_data
fn print_example_f = nullptr;
};

// Merges the per-reduction data of several models (weighted by per_model_weighting) into one
// output model, with access to each model's enclosing workspace.
using merge_with_all_fn = void (*)(const std::vector<float>& per_model_weighting,
    const std::vector<const VW::workspace*>& all_workspaces, const std::vector<const void*>& all_data,
    VW::workspace& output_workspace, void* output_data);
// When the workspace reference is not needed, this signature should be preferred.
using merge_fn = void (*)(
    const std::vector<float>& per_model_weighting, const std::vector<const void*>& all_data, void* output_data);
// Element-wise combination of two per-reduction data objects into data_out.
// Shared by both the add (apply delta) and subtract (produce delta) operations.
using add_subtract_fn = void (*)(const void* data_1, const void* data_2, void* data_out);
// Same as add_subtract_fn, but additionally receives the enclosing workspaces.
using add_subtract_with_all_fn = void (*)(const VW::workspace& ws1, const void* data1, const VW::workspace& ws2,
    const void* data2, VW::workspace& ws_out, void* data_out);

// No-op defaults used when a reduction does not implement save/load or metrics persistence.
inline void noop_save_load(void*, io_buf&, bool, bool) {}
inline void noop_persist_metrics(void*, metric_sink&) {}
Expand Down Expand Up @@ -253,6 +256,10 @@ struct learner
// There should only ever be either none, or one of these two set. Never both.
details::merge_with_all_fn _merge_with_all_fn;
details::merge_fn _merge_fn;
details::add_subtract_fn _add_fn;
details::add_subtract_with_all_fn _add_with_all_fn;
details::add_subtract_fn _subtract_fn;
details::add_subtract_with_all_fn _subtract_with_all_fn;

std::shared_ptr<void> learner_data;

Expand Down Expand Up @@ -466,17 +473,17 @@ struct learner

// This is effectively static implementing a trait for this learner type.
// NOT auto recursive
void NO_SANITIZE_UNDEFINED merge(const std::vector<float>& per_model_weighting, const VW::workspace& base_workspace,
const std::vector<const VW::workspace*>& all_workspaces, const base_learner* base_workspaces_learner,
const std::vector<const base_learner*>& all_learners, VW::workspace& output_workspace,
base_learner* output_learner)
void NO_SANITIZE_UNDEFINED merge(const std::vector<float>& per_model_weighting,
const std::vector<const VW::workspace*>& all_workspaces, const std::vector<const base_learner*>& all_learners,
VW::workspace& output_workspace, base_learner& output_learner)
{
assert(per_model_weighting.size() == all_workspaces.size());
assert(per_model_weighting.size() == all_learners.size());

#ifndef NDEBUG
// All learners should refer to the same learner 'type'
const auto& name = base_workspaces_learner->get_name();
assert(!all_learners.empty());
const auto& name = all_learners[0]->get_name();
for (const auto& learner : all_learners) { assert(learner->get_name() == name); }
#endif

Expand All @@ -486,21 +493,67 @@ struct learner

if (_merge_with_all_fn != nullptr)
{
_merge_with_all_fn(per_model_weighting, base_workspace, all_workspaces,
base_workspaces_learner->learner_data.get(), all_data, output_workspace, output_learner->learner_data.get());
_merge_with_all_fn(
per_model_weighting, all_workspaces, all_data, output_workspace, output_learner.learner_data.get());
}
else if (_merge_fn != nullptr)
{
_merge_fn(per_model_weighting, base_workspaces_learner->learner_data.get(), all_data,
output_learner->learner_data.get());
_merge_fn(per_model_weighting, all_data, output_learner.learner_data.get());
}
else
{
THROW("learner " << name << " does not support merging.");
}
}

// Apply a delta (delta_l / delta_ws) on top of a base model (base_l / base_ws), writing the
// result into output_l / output_ws. Dispatches to whichever add function this learner
// registered; throws if neither is set. Like merge, this is NOT auto recursive.
void NO_SANITIZE_UNDEFINED add(const VW::workspace& base_ws, const VW::workspace& delta_ws,
    const base_learner* base_l, const base_learner* delta_l, VW::workspace& output_ws, base_learner* output_l)
{
  // Bind by const reference: avoids copying the name (merge() uses the same pattern), and the
  // value is only read by the asserts and the THROW message.
  const auto& name = output_l->get_name();
  // All three learners must implement the same reduction 'type'.
  assert(name == base_l->get_name());
  assert(name == delta_l->get_name());
  if (_add_with_all_fn != nullptr)
  {
    _add_with_all_fn(base_ws, base_l->learner_data.get(), delta_ws, delta_l->learner_data.get(), output_ws,
        output_l->learner_data.get());
  }
  else if (_add_fn != nullptr)
  {
    _add_fn(base_l->learner_data.get(), delta_l->learner_data.get(), output_l->learner_data.get());
  }
  else
  {
    THROW("learner " << name << " does not support adding a delta.");
  }
}

// Compute the delta between two models (l1/ws1 minus l2/ws2), writing the result into
// output_l / output_ws. Dispatches to whichever subtract function this learner registered;
// throws if neither is set. Like merge, this is NOT auto recursive.
void NO_SANITIZE_UNDEFINED subtract(const VW::workspace& ws1, const VW::workspace& ws2, const base_learner* l1,
    const base_learner* l2, VW::workspace& output_ws, base_learner* output_l)
{
  // Bind by const reference: avoids copying the name (merge() uses the same pattern).
  const auto& name = output_l->get_name();
  // All three learners must implement the same reduction 'type'.
  assert(name == l1->get_name());
  assert(name == l2->get_name());
  if (_subtract_with_all_fn != nullptr)
  {
    _subtract_with_all_fn(
        ws1, l1->learner_data.get(), ws2, l2->learner_data.get(), output_ws, output_l->learner_data.get());
  }
  else if (_subtract_fn != nullptr)
  {
    _subtract_fn(l1->learner_data.get(), l2->learner_data.get(), output_l->learner_data.get());
  }
  else
  {
    THROW("learner " << name << " does not support subtraction to generate a delta.");
  }
}

// True if either form of the merge function has been registered for this learner.
VW_ATTR(nodiscard) bool has_merge() const { return (_merge_with_all_fn != nullptr) || (_merge_fn != nullptr); }
// True if either form of the add function has been registered for this learner.
VW_ATTR(nodiscard) bool has_add() const { return (_add_with_all_fn != nullptr) || (_add_fn != nullptr); }
// True if either form of the subtract function has been registered for this learner.
VW_ATTR(nodiscard) bool has_subtract() const
{
return (_subtract_with_all_fn != nullptr) || (_subtract_fn != nullptr);
}
VW_ATTR(nodiscard) prediction_type_t get_output_prediction_type() const { return _output_pred_type; }
VW_ATTR(nodiscard) prediction_type_t get_input_prediction_type() const { return _input_pred_type; }
VW_ATTR(nodiscard) label_type_t get_output_label_type() const { return _output_label_type; }
Expand Down Expand Up @@ -741,6 +794,10 @@ struct reduction_learner_builder
// Don't propagate merge functions
this->_learner->_merge_fn = nullptr;
this->_learner->_merge_with_all_fn = nullptr;
this->_learner->_add_fn = nullptr;
this->_learner->_add_with_all_fn = nullptr;
this->_learner->_subtract_fn = nullptr;
this->_learner->_subtract_with_all_fn = nullptr;

set_params_per_weight(1);
this->set_learn_returns_prediction(false);
Expand All @@ -762,14 +819,27 @@ struct reduction_learner_builder
return *this;
}

// Register a merge function for this reduction's typed data. The typed function pointer is
// stored as the type-erased details::merge_fn; the cast is safe because the learner only ever
// calls it back with this reduction's own learner_data.
reduction_learner_builder<DataT, ExampleT, BaseLearnerT>& set_merge(void (*merge_fn)(
    const std::vector<float>& per_model_weighting, const std::vector<const DataT*>& all_data, DataT& output_data))
{
  this->_learner->_merge_fn = reinterpret_cast<details::merge_fn>(merge_fn);
  return *this;
}

// Register a function that adds two instances of this reduction's data into data_out.
// Used when applying a model delta to a base model (see learner::add).
reduction_learner_builder<DataT, ExampleT, BaseLearnerT>& set_add(
void (*add_fn)(const DataT& data1, const DataT& data2, DataT& data_out))
{
this->_learner->_add_fn = reinterpret_cast<details::add_subtract_fn>(add_fn);
return *this;
}

// Register a function that subtracts data2 from data1 into data_out.
// Used when producing a model delta from two models (see learner::subtract).
reduction_learner_builder<DataT, ExampleT, BaseLearnerT>& set_subtract(
void (*subtract_fn)(const DataT& data1, const DataT& data2, DataT& data_out))
{
this->_learner->_subtract_fn = reinterpret_cast<details::add_subtract_fn>(subtract_fn);
return *this;
}

learner<DataT, ExampleT>* build(VW::io::logger* logger = nullptr)
{
if (logger != nullptr)
Expand Down Expand Up @@ -889,15 +959,28 @@ struct base_learner_builder
return *this;
}

// Register a merge function that also receives the enclosing workspaces (base learners only).
// NOTE(review): all_data is std::vector<DataT*> here while details::merge_with_all_fn takes
// std::vector<const void*> — the reinterpret_cast hides the mismatch; consider const DataT* for
// consistency with set_merge. Verify no registered function relies on mutating all_data.
base_learner_builder<DataT, ExampleT>& set_merge_with_all(void (*merge_with_all_fn)(
    const std::vector<float>& per_model_weighting, const std::vector<const VW::workspace*>& all_workspaces,
    const std::vector<DataT*>& all_data, VW::workspace& output_workspace, DataT& output_data))
{
  this->_learner->_merge_with_all_fn = reinterpret_cast<details::merge_with_all_fn>(merge_with_all_fn);
  return *this;
}

// Register an add function that also receives the enclosing workspaces (base learners only).
// NOTE(review): data2 is non-const here but details::add_subtract_with_all_fn declares it as
// const void* — consider const DataT& data2 for consistency; confirm no caller mutates it.
base_learner_builder<DataT, ExampleT>& set_add_with_all(void (*add_with_all_fn)(const VW::workspace& ws1,
const DataT& data1, const VW::workspace& ws2, DataT& data2, VW::workspace& ws_out, DataT& data_out))
{
this->_learner->_add_with_all_fn = reinterpret_cast<details::add_subtract_with_all_fn>(add_with_all_fn);
return *this;
}

// Register a subtract function that also receives the enclosing workspaces (base learners only).
// NOTE(review): data2 is non-const here but details::add_subtract_with_all_fn declares it as
// const void* — consider const DataT& data2 for consistency; confirm no caller mutates it.
base_learner_builder<DataT, ExampleT>& set_subtract_with_all(void (*subtract_with_all_fn)(const VW::workspace& ws1,
const DataT& data1, const VW::workspace& ws2, DataT& data2, VW::workspace& ws_out, DataT& data_out))
{
this->_learner->_subtract_with_all_fn = reinterpret_cast<details::add_subtract_with_all_fn>(subtract_with_all_fn);
return *this;
}

learner<DataT, ExampleT>* build()
{
if (this->_learner->_merge_fn != nullptr && this->_learner->_merge_with_all_fn != nullptr)
Expand Down
43 changes: 41 additions & 2 deletions vowpalwabbit/core/include/vw/core/merge.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,28 @@
#include "vw/core/global_data.h"
#include "vw/io/logger.h"

#include <memory>

namespace VW
{
// Owns the VW::workspace that represents the difference (delta) between two compatible models.
// Produced by operator-(ws1, ws2) and applied with operator+(ws, delta).
struct model_delta
{
  // model_delta takes ownership of the workspace. `ws` is expected to be non-null.
  explicit model_delta(VW::workspace* ws) : _ws(ws) {}
  // Sink constructor: takes ownership of the workspace. Pass with std::move.
  // (By-value unique_ptr is the idiomatic sink signature; callers passing rvalues are unaffected.)
  explicit model_delta(std::unique_ptr<VW::workspace> ws) : _ws(std::move(ws)) {}

  // retrieve a raw pointer to the underlying VW::workspace; ownership is retained.
  // unsafe, only for use in implementation of model merging and its corresponding unit tests
  VW::workspace* unsafe_get_workspace_ptr() const { return _ws.get(); }

  // release ownership of the pointer to the underlying VW::workspace, and return it.
  // The caller becomes responsible for deleting the workspace.
  // unsafe, only for use in implementation of model merging and its corresponding unit tests
  VW::workspace* unsafe_release_workspace_ptr() { return _ws.release(); }

private:
  std::unique_ptr<VW::workspace> _ws;
};

/**
* Merge the differences of several workspaces into the given base workspace. This merges weights
* and all training state. All given workspaces must be compatible with the base workspace, meaning they
Expand All @@ -18,10 +38,29 @@ namespace VW
*
* @param base_workspace Optional common base model that all other models continued training from. If not supplied, then
* all models are assumed to be trained from scratch.
* @param workspaces_to_merge Vector of workspaces to merge.
* @param logger Optional logger to be used for logging during function and is given to the resulting workspace
* @return std::unique_ptr<VW::workspace> Pointer to the resulting workspace.
*/
std::unique_ptr<VW::workspace> merge_models(const VW::workspace* base_workspace,
const std::vector<const VW::workspace*>& workspaces_to_merge, VW::io::logger* logger = nullptr);
/**
 * Merge several model deltas into a single delta. This merges weights
 * and all training state. All given deltas must be from compatible models, meaning they
 * should have the same reduction stack and same training based options. All deltas are
 * assumed to be generated using a single shared base workspace.
 *
 * Note: This is an experimental API.
 *
 * @param deltas_to_merge Vector of model deltas to merge.
 * @param logger Optional logger to be used for logging during function and is given to the resulting workspace
 * @return VW::model_delta The merged delta.
 */
VW::model_delta merge_deltas(
    const std::vector<const VW::model_delta*>& deltas_to_merge, VW::io::logger* logger = nullptr);
}  // namespace VW

// Apply a delta to a base workspace, producing a new workspace. Operands are not modified.
std::unique_ptr<VW::workspace> operator+(const VW::workspace& ws, const VW::model_delta& md);

// Compute the delta between two compatible workspaces (ws1 minus ws2). Operands are not modified.
VW::model_delta operator-(const VW::workspace& ws1, const VW::workspace& ws2);
Loading