Model merging
Model merging takes several compatible VW models and merges them into a single model that approximately represents all of the models combined. This will likely never be as effective as a single model trained on all of the data sequentially. However, when training against all of the data sequentially is not feasible, the speedup from parallel computation can make a merged model that sees all of the data more effective than a single model trained on only a subset of it.
Generally speaking, merging is a weighted average of the given models, weighted by the relative amount of data each one processed. Values which act as counters are accumulated instead of averaged.
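As a rough illustration of this rule, here is a standalone sketch (hypothetical types and names, not VW's actual merge implementation) that averages flat weight vectors by example share and accumulates a counter:

```cpp
#include <cstddef>
#include <vector>

// Toy stand-in for a model: a flat weight vector plus an example counter.
struct toy_model
{
  std::vector<float> weights;
  float example_count = 0.f;  // counter-like value: accumulated, not averaged
};

// Assumes at least one model and equal-length weight vectors.
toy_model merge_models(const std::vector<toy_model>& models)
{
  toy_model out;
  out.weights.assign(models[0].weights.size(), 0.f);

  float total_examples = 0.f;
  for (const auto& m : models) { total_examples += m.example_count; }

  for (const auto& m : models)
  {
    // Each model contributes in proportion to the amount of data it processed.
    const float coefficient = m.example_count / total_examples;
    for (std::size_t i = 0; i < out.weights.size(); i++) { out.weights[i] += coefficient * m.weights[i]; }
    // Counters are accumulated instead of averaged.
    out.example_count += m.example_count;
  }
  return out;
}
```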
In the case of the GD reduction, when `save_resume` is in use, the `adaptive` values are used to compute a per-parameter weighted average. For all other averaged values in a model, the number of examples seen by that model is used as its weight in the average.
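The per-parameter variant might look like the following sketch, assuming each model exposes a per-weight `adaptive` accumulator (the layout and names here are illustrative, not VW's internals):

```cpp
#include <cstddef>
#include <vector>

// Per-parameter weighted average: for each weight index, a model's
// contribution is scaled by its share of the adaptive mass for that
// parameter, rather than by a single global example count.
void merge_gd_weights(const std::vector<std::vector<float>>& all_weights,
    const std::vector<std::vector<float>>& all_adaptive,  // one adaptive value per weight, per model
    std::vector<float>& out_weights)
{
  for (std::size_t i = 0; i < out_weights.size(); i++)
  {
    float total_adaptive = 0.f;
    for (const auto& adaptive : all_adaptive) { total_adaptive += adaptive[i]; }
    if (total_adaptive <= 0.f) { continue; }  // no model touched this weight

    float merged = 0.f;
    for (std::size_t m = 0; m < all_weights.size(); m++)
    {
      merged += (all_adaptive[m][i] / total_adaptive) * all_weights[m][i];
    }
    out_weights[i] = merged;
  }
}
```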
If a reduction defines a `save_load` function, this implies that the reduction has training state which is persisted. Therefore, a rule of thumb is that if a reduction defines `save_load` it must also define `merge` (see the sketch at the end of this page). A warning is emitted if any reduction in the stack has a `save_load` but no `merge`, and an error is emitted if the base reduction in the stack has no `merge`, since merging definitely cannot work in that case.
The signature of the merge function depends on whether the reduction is a base reduction. Ideally, all `merge` functions would use the non-base signature, but since base learners use the weights and other state in `VW::workspace`, that is not currently feasible.
```cpp
using ReductionDataT = void;  // ...

// Base reduction
using merge_with_all_fn = void (*)(const std::vector<float>& example_counts,
    const std::vector<const VW::workspace*>& all_workspaces, const std::vector<const ReductionDataT*>& all_data,
    VW::workspace& output_workspace, ReductionDataT& output_data);

// Non-base reduction
using merge_fn = void (*)(const std::vector<float>& example_counts,
    const std::vector<const ReductionDataT*>& all_data, ReductionDataT& output_data);
```
This is then set on the respective learner builder during construction. `merge` is then exposed by the `learner` interface.
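For concreteness, here is a minimal sketch of a non-base merge function matching the `merge_fn` signature above, for a hypothetical reduction whose persisted state is an event counter plus a running average. The builder registration shown in the trailing comment is an assumption based on the description above, not a verified API:

```cpp
#include <cstddef>
#include <vector>

// Hypothetical reduction state: if this were persisted via save_load,
// the reduction should also define how to merge it.
struct my_reduction_data
{
  float event_count = 0.f;   // counter-like: accumulated across models
  float running_mean = 0.f;  // averaged, weighted by examples seen
};

// Matches the non-base merge_fn signature shown above.
void my_merge(const std::vector<float>& example_counts,
    const std::vector<const my_reduction_data*>& all_data, my_reduction_data& output_data)
{
  float total = 0.f;
  for (float c : example_counts) { total += c; }

  for (std::size_t i = 0; i < all_data.size(); i++)
  {
    // Counters accumulate; averaged values are weighted by example share.
    output_data.event_count += all_data[i]->event_count;
    output_data.running_mean += (example_counts[i] / total) * all_data[i]->running_mean;
  }
}

// During construction, the function would then be attached to the learner
// builder, e.g. (assumed spelling of the setter):
//   make_reduction_learner(...).set_merge(my_merge) ...
```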