Skip to content

Commit a9b305f

Browse files
authored
refactor: [workspace] split 'all' into multiple structs. split config and runtime vars. (#4493)
1 parent adcaff2 commit a9b305f

File tree

141 files changed

+2296
-1831
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below for content that may be hidden.

141 files changed

+2296
-1831
lines changed

cs/cli/vowpalwabbit.cpp

+19-19
Original file line number | Diff line number | Diff line change
@@ -32,16 +32,16 @@ VowpalWabbit::VowpalWabbit(VowpalWabbitSettings^ settings)
3232
}
3333

3434
if (settings->ParallelOptions != nullptr)
35-
{ m_vw->selected_all_reduce_type = all_reduce_type::THREAD;
35+
{ m_vw->runtime_config.selected_all_reduce_type = all_reduce_type::THREAD;
3636
auto total = settings->ParallelOptions->MaxDegreeOfParallelism;
3737

3838
if (settings->Root == nullptr)
39-
{ m_vw->all_reduce.reset(new all_reduce_threads(total, settings->Node));
39+
{ m_vw->runtime_state.all_reduce.reset(new all_reduce_threads(total, settings->Node));
4040
}
4141
else
42-
{ auto parent_all_reduce = (all_reduce_threads*)settings->Root->m_vw->all_reduce.get();
42+
{ auto parent_all_reduce = (all_reduce_threads*)settings->Root->m_vw->runtime_state.all_reduce.get();
4343

44-
m_vw->all_reduce.reset(new all_reduce_threads(parent_all_reduce, total, settings->Node));
44+
m_vw->runtime_state.all_reduce.reset(new all_reduce_threads(parent_all_reduce, total, settings->Node));
4545
}
4646
}
4747

@@ -64,9 +64,9 @@ void VowpalWabbit::Driver()
6464
}
6565

6666
void VowpalWabbit::RunMultiPass()
67-
{ if (m_vw->numpasses > 1)
67+
{ if (m_vw->runtime_config.numpasses > 1)
6868
{ try
69-
{ m_vw->do_reset_source = true;
69+
{ m_vw->runtime_state.do_reset_source = true;
7070
VW::start_parser(*m_vw);
7171
LEARNER::generic_driver(*m_vw);
7272
VW::end_parser(*m_vw);
@@ -79,17 +79,17 @@ VowpalWabbitPerformanceStatistics^ VowpalWabbit::PerformanceStatistics::get()
7979
{ // see parse_args.cc:finish(...)
8080
auto stats = gcnew VowpalWabbitPerformanceStatistics();
8181

82-
if (m_vw->current_pass == 0)
82+
if (m_vw->passes_config.current_pass == 0)
8383
{ stats->NumberOfExamplesPerPass = m_vw->sd->example_number;
8484
}
8585
else
86-
{ stats->NumberOfExamplesPerPass = m_vw->sd->example_number / m_vw->current_pass;
86+
{ stats->NumberOfExamplesPerPass = m_vw->sd->example_number / m_vw->passes_config.current_pass;
8787
}
8888

8989
stats->WeightedExampleSum = m_vw->sd->weighted_examples();
9090
stats->WeightedLabelSum = m_vw->sd->weighted_labels;
9191

92-
if (m_vw->holdout_set_off)
92+
if (m_vw->passes_config.holdout_set_off)
9393
if (m_vw->sd->weighted_labeled_examples > 0)
9494
stats->AverageLoss = m_vw->sd->sum_loss / m_vw->sd->weighted_labeled_examples;
9595
else
@@ -100,7 +100,7 @@ VowpalWabbitPerformanceStatistics^ VowpalWabbit::PerformanceStatistics::get()
100100
stats->AverageLoss = m_vw->sd->holdout_best_loss;
101101

102102
float best_constant; float best_constant_loss;
103-
if (get_best_constant(*m_vw->loss, *m_vw->sd, best_constant, best_constant_loss))
103+
if (get_best_constant(*m_vw->loss_config.loss, *m_vw->sd, best_constant, best_constant_loss))
104104
{ stats->BestConstant = best_constant;
105105
if (best_constant_loss != FLT_MIN)
106106
{ stats->BestConstantLoss = best_constant_loss;
@@ -124,7 +124,7 @@ uint64_t VowpalWabbit::HashSpace(String^ s)
124124
}
125125

126126
uint64_t VowpalWabbit::HashFeature(String^ s, size_t u)
127-
{ auto newHash = m_hasher(s, u) & m_vw->parse_mask;
127+
{ auto newHash = m_hasher(s, u) & m_vw->runtime_state.parse_mask;
128128

129129
#ifdef _DEBUG
130130
auto oldHash = HashFeatureNative(s, u);
@@ -321,7 +321,7 @@ List<VowpalWabbitExample^>^ VowpalWabbit::ParseDecisionServiceJson(cli::array<By
321321

322322
VW::parsers::json::decision_service_interaction interaction;
323323

324-
if (m_vw->audit)
324+
if (m_vw->output_config.audit)
325325
VW::parsers::json::read_line_decision_service_json<true>(*m_vw, examples, reinterpret_cast<char*>(data), length, copyJson, std::bind(get_example_from_pool, &state), &interaction);
326326
else
327327
VW::parsers::json::read_line_decision_service_json<false>(*m_vw, examples, reinterpret_cast<char*>(data), length, copyJson, std::bind(get_example_from_pool, &state), &interaction);
@@ -385,7 +385,7 @@ List<VowpalWabbitExample^>^ VowpalWabbit::ParseDecisionServiceJson(cli::array<By
385385

386386
interior_ptr<ParseJsonState^> state_ptr = &state;
387387

388-
if (m_vw->audit)
388+
if (m_vw->output_config.audit)
389389
VW::parsers::json::read_line_json<true>(*m_vw, examples, reinterpret_cast<char*>(valueHandle.AddrOfPinnedObject().ToPointer()), (size_t)bytes->Length, std::bind(get_example_from_pool, &state));
390390
else
391391
VW::parsers::json::read_line_json<false>(*m_vw, examples, reinterpret_cast<char*>(valueHandle.AddrOfPinnedObject().ToPointer()), (size_t)bytes->Length, std::bind(get_example_from_pool, &state));
@@ -793,15 +793,15 @@ VowpalWabbitExample^ VowpalWabbit::GetOrCreateNativeExample()
793793
{ try
794794
{
795795
auto ex = new VW::example;
796-
m_vw->example_parser->lbl_parser.default_label(ex->l);
796+
m_vw->parser_runtime.example_parser->lbl_parser.default_label(ex->l);
797797
return gcnew VowpalWabbitExample(this, ex);
798798
}
799799
CATCHRETHROW
800800
}
801801

802802
try
803803
{ VW::empty_example(*m_vw, *ex->m_example);
804-
m_vw->example_parser->lbl_parser.default_label(ex->m_example->l);
804+
m_vw->parser_runtime.example_parser->lbl_parser.default_label(ex->m_example->l);
805805

806806
return ex;
807807
}
@@ -833,9 +833,9 @@ void VowpalWabbit::ReturnExampleToPool(VowpalWabbitExample^ ex)
833833
}
834834

835835
cli::array<List<VowpalWabbitFeature^>^>^ VowpalWabbit::GetTopicAllocation(int top)
836-
{ uint64_t length = (uint64_t)1 << m_vw->num_bits;
836+
{ uint64_t length = (uint64_t)1 << m_vw->initial_weights_config.num_bits;
837837
// using jagged array to enable LINQ
838-
auto K = (int)m_vw->lda;
838+
auto K = (int)m_vw->reduction_state.lda;
839839
auto allocation = gcnew cli::array<List<VowpalWabbitFeature^>^>(K);
840840

841841
// TODO: better way of peaking into lda?
@@ -858,10 +858,10 @@ cli::array<List<VowpalWabbitFeature^>^>^ VowpalWabbit::GetTopicAllocation(int to
858858
template<typename T>
859859
cli::array<cli::array<float>^>^ VowpalWabbit::FillTopicAllocation(T& weights)
860860
{
861-
uint64_t length = (uint64_t)1 << m_vw->num_bits;
861+
uint64_t length = (uint64_t)1 << m_vw->initial_weights_config.num_bits;
862862

863863
// using jagged array to enable LINQ
864-
auto K = (int)m_vw->lda;
864+
auto K = (int)m_vw->reduction_state.lda;
865865
auto allocation = gcnew cli::array<cli::array<float>^>(K);
866866
for (int k = 0; k < K; k++)
867867
allocation[k] = gcnew cli::array<float>((int)length);

cs/cli/vw_arguments.h

+8-9
Original file line number | Diff line number | Diff line change
@@ -36,18 +36,17 @@ public ref class VowpalWabbitArguments
3636
float m_power_t;
3737

3838
internal : VowpalWabbitArguments(VW::workspace* vw)
39-
: m_data(gcnew String(vw->data_filename.c_str()))
40-
, m_finalRegressor(gcnew String(vw->final_regressor_name.c_str()))
41-
, m_testonly(!vw->training)
42-
, m_passes((int)vw->numpasses)
39+
: m_data(gcnew String(vw->parser_runtime.data_filename.c_str()))
40+
, m_finalRegressor(gcnew String(vw->output_model_config.final_regressor_name.c_str()))
41+
, m_testonly(!vw->runtime_config.training)
42+
, m_passes((int)vw->runtime_config.numpasses)
4343
{
4444
auto options = vw->options.get();
4545

46-
if (vw->initial_regressors.size() > 0)
46+
if (vw->initial_weights_config.initial_regressors.size() > 0)
4747
{ m_regressors = gcnew List<String^>;
4848

49-
for (auto& r : vw->initial_regressors)
50-
m_regressors->Add(gcnew String(r.c_str()));
49+
for (auto& r : vw->initial_weights_config.initial_regressors) m_regressors->Add(gcnew String(r.c_str()));
5150
}
5251

5352
VW::config::cli_options_serializer serializer;
@@ -66,8 +65,8 @@ public ref class VowpalWabbitArguments
6665
m_numberOfActions = (int)options->get_typed_option<uint32_t>("cb").value();
6766
}
6867

69-
m_learning_rate = vw->eta;
70-
m_power_t = vw->power_t;
68+
m_learning_rate = vw->update_rule_config.eta;
69+
m_power_t = vw->update_rule_config.power_t;
7170
}
7271

7372
public:

cs/cli/vw_base.cpp

+3-3
Original file line number | Diff line number | Diff line change
@@ -150,7 +150,7 @@ void VowpalWabbitBase::InternalDispose()
150150
try
151151
{ if (m_vw != nullptr)
152152
{
153-
VW::details::reset_source(*m_vw, m_vw->num_bits);
153+
VW::details::reset_source(*m_vw, m_vw->initial_weights_config.num_bits);
154154

155155
// make sure don't try to free m_vw twice in case VW::finish throws.
156156
VW::workspace* vw_tmp = m_vw;
@@ -187,7 +187,7 @@ void VowpalWabbitBase::Reload([System::Runtime::InteropServices::Optional] Strin
187187

188188
try
189189
{
190-
VW::details::reset_source(*m_vw, m_vw->num_bits);
190+
VW::details::reset_source(*m_vw, m_vw->initial_weights_config.num_bits);
191191

192192
auto buffer = std::make_shared<std::vector<char>>();
193193
{
@@ -225,7 +225,7 @@ void VowpalWabbitBase::ID::set(String^ value)
225225
}
226226

227227
void VowpalWabbitBase::SaveModel()
228-
{ std::string name = m_vw->final_regressor_name;
228+
{ std::string name = m_vw->output_model_config.final_regressor_name;
229229
if (name.empty())
230230
{ return;
231231
}

cs/cli/vw_example.cpp

+3-3
Original file line number | Diff line number | Diff line change
@@ -73,7 +73,7 @@ bool VowpalWabbitExample::IsNewLine::get()
7373

7474
ILabel^ VowpalWabbitExample::Label::get()
7575
{ ILabel^ label;
76-
auto lp = m_owner->Native->m_vw->example_parser->lbl_parser;
76+
auto lp = m_owner->Native->m_vw->parser_runtime.example_parser->lbl_parser;
7777
if (!memcmp(&lp, &VW::simple_label_parser_global, sizeof(lp)))
7878
label = gcnew SimpleLabel();
7979
else if (!memcmp(&lp, &VW::cb_label_parser_global, sizeof(lp)))
@@ -103,7 +103,7 @@ void VowpalWabbitExample::Label::set(ILabel^ label)
103103
label->UpdateExample(m_owner->Native->m_vw, m_example);
104104

105105
// we need to update the example weight as setup_example() can be called prior to this call.
106-
m_example->weight = m_owner->Native->m_vw->example_parser->lbl_parser.get_weight(m_example->l, m_example->ex_reduction_features);
106+
m_example->weight = m_owner->Native->m_vw->parser_runtime.example_parser->lbl_parser.get_weight(m_example->l, m_example->ex_reduction_features);
107107
}
108108

109109
void VowpalWabbitExample::MakeEmpty(VowpalWabbit^ vw)
@@ -389,7 +389,7 @@ uint64_t VowpalWabbitFeature::WeightIndex::get()
389389
throw gcnew InvalidOperationException("VowpalWabbitFeature must be initialized with example");
390390

391391
VW::workspace* vw = m_example->Owner->Native->m_vw;
392-
return ((m_weight_index + m_example->m_example->ft_offset) >> vw->weights.stride_shift()) & vw->parse_mask;
392+
return ((m_weight_index + m_example->m_example->ft_offset) >> vw->weights.stride_shift()) & vw->runtime_state.parse_mask;
393393
}
394394

395395
float VowpalWabbitFeature::Weight::get()

cs/cli/vw_prediction.cpp

+2-2
Original file line number | Diff line number | Diff line change
@@ -155,8 +155,8 @@ cli::array<float>^ VowpalWabbitTopicPredictionFactory::Create(VW::workspace* vw,
155155
{ if (ex == nullptr)
156156
throw gcnew ArgumentNullException("ex");
157157

158-
auto values = gcnew cli::array<float>(vw->lda);
159-
Marshal::Copy(IntPtr(ex->pred.scalars.begin()), values, 0, vw->lda);
158+
auto values = gcnew cli::array<float>(vw->reduction_state.lda);
159+
Marshal::Copy(IntPtr(ex->pred.scalars.begin()), values, 0, vw->reduction_state.lda);
160160

161161
return values;
162162
}

cs/vw.net.native/vw.net.arguments.cc

+13-10
Original file line number | Diff line number | Diff line change
@@ -6,10 +6,10 @@
66
API void GetWorkspaceBasicArguments(
77
vw_net_native::workspace_context* workspace, vw_net_native::vw_basic_arguments_t* args)
88
{
9-
args->is_test_only = !workspace->vw->training;
10-
args->num_passes = (int)workspace->vw->numpasses;
11-
args->learning_rate = workspace->vw->eta;
12-
args->power_t = workspace->vw->power_t;
9+
args->is_test_only = !workspace->vw->runtime_config.training;
10+
args->num_passes = (int)workspace->vw->runtime_config.numpasses;
11+
args->learning_rate = workspace->vw->update_rule_config.eta;
12+
args->power_t = workspace->vw->update_rule_config.power_t;
1313

1414
if (workspace->vw->options->was_supplied("cb"))
1515
{
@@ -19,12 +19,12 @@ API void GetWorkspaceBasicArguments(
1919

2020
API const char* GetWorkspaceDataFilename(vw_net_native::workspace_context* workspace)
2121
{
22-
return workspace->vw->data_filename.c_str();
22+
return workspace->vw->parser_runtime.data_filename.c_str();
2323
}
2424

2525
API const char* GetFinalRegressorFilename(vw_net_native::workspace_context* workspace)
2626
{
27-
return workspace->vw->final_regressor_name.c_str();
27+
return workspace->vw->output_model_config.final_regressor_name.c_str();
2828
}
2929

3030
API char* SerializeCommandLine(vw_net_native::workspace_context* workspace)
@@ -42,20 +42,23 @@ API char* SerializeCommandLine(vw_net_native::workspace_context* workspace)
4242

4343
API size_t GetInitialRegressorFilenamesCount(vw_net_native::workspace_context* workspace)
4444
{
45-
return workspace->vw->initial_regressors.size();
45+
return workspace->vw->initial_weights_config.initial_regressors.size();
4646
}
4747

4848
API vw_net_native::dotnet_size_t GetInitialRegressorFilenames(
4949
vw_net_native::workspace_context* workspace, const char** filenames, vw_net_native::dotnet_size_t count)
5050
{
51-
std::vector<std::string>& initial_regressors = workspace->vw->initial_regressors;
51+
std::vector<std::string>& initial_regressors = workspace->vw->initial_weights_config.initial_regressors;
5252
size_t size = initial_regressors.size();
5353
if ((size_t)count < size)
5454
{
5555
return vw_net_native::size_to_neg_dotnet_size(size); // Not enough space in destination buffer
5656
}
5757

58-
for (size_t i = 0; i < size; i++) { filenames[i] = workspace->vw->initial_regressors[i].c_str(); }
58+
for (size_t i = 0; i < size; i++)
59+
{
60+
filenames[i] = workspace->vw->initial_weights_config.initial_regressors[i].c_str();
61+
}
5962

60-
return workspace->vw->initial_regressors.size();
63+
return workspace->vw->initial_weights_config.initial_regressors.size();
6164
}

cs/vw.net.native/vw.net.example.cc

+5-4
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,7 @@
1010
API VW::example* CreateExample(vw_net_native::workspace_context* workspace)
1111
{
1212
auto* ex = new VW::example;
13-
workspace->vw->example_parser->lbl_parser.default_label(ex->l);
13+
workspace->vw->parser_runtime.example_parser->lbl_parser.default_label(ex->l);
1414
return ex;
1515
}
1616

@@ -189,12 +189,13 @@ API void MakeIntoNewlineExample(vw_net_native::workspace_context* workspace, VW:
189189

190190
API void MakeLabelDefault(vw_net_native::workspace_context* workspace, VW::example* example)
191191
{
192-
workspace->vw->example_parser->lbl_parser.default_label(example->l);
192+
workspace->vw->parser_runtime.example_parser->lbl_parser.default_label(example->l);
193193
}
194194

195195
API void UpdateExampleWeight(vw_net_native::workspace_context* workspace, VW::example* example)
196196
{
197-
example->weight = workspace->vw->example_parser->lbl_parser.get_weight(example->l, example->ex_reduction_features);
197+
example->weight =
198+
workspace->vw->parser_runtime.example_parser->lbl_parser.get_weight(example->l, example->ex_reduction_features);
198199
}
199200

200201
API vw_net_native::namespace_enumerator* CreateNamespaceEnumerator(
@@ -256,7 +257,7 @@ API VW::feature_index GetShiftedWeightIndex(
256257
vw_net_native::workspace_context* workspace, VW::example* example, VW::feature_index feature_index)
257258
{
258259
VW::workspace* vw = workspace->vw;
259-
return ((feature_index + example->ft_offset) >> vw->weights.stride_shift()) & vw->parse_mask;
260+
return ((feature_index + example->ft_offset) >> vw->weights.stride_shift()) & vw->runtime_state.parse_mask;
260261
}
261262

262263
API float GetWeight(vw_net_native::workspace_context* workspace, VW::example* example, VW::feature_index feature_index)

cs/vw.net.native/vw.net.predictions.cc

+6-3
Original file line number | Diff line number | Diff line change
@@ -80,14 +80,17 @@ API vw_net_native::dotnet_size_t GetPredictionActionScores(
8080
return vw_net_native::v_copy_to_managed(ex->pred.a_s, values, count);
8181
}
8282

83-
API size_t GetPredictionTopicProbsCount(VW::workspace* vw, VW::example* ex) { return static_cast<size_t>(vw->lda); }
83+
API size_t GetPredictionTopicProbsCount(VW::workspace* vw, VW::example* ex)
84+
{
85+
return static_cast<size_t>(vw->reduction_state.lda);
86+
}
8487

8588
API vw_net_native::dotnet_size_t GetPredictionTopicProbs(
8689
VW::workspace* vw, VW::example* ex, float* values, vw_net_native::dotnet_size_t count)
8790
{
88-
if (count < vw->lda)
91+
if (count < vw->reduction_state.lda)
8992
{
90-
return vw_net_native::size_to_neg_dotnet_size(vw->lda); // not enough space in the output array
93+
return vw_net_native::size_to_neg_dotnet_size(vw->reduction_state.lda); // not enough space in the output array
9194
}
9295

9396
const v_array<float>& scalars = ex->pred.scalars;

0 commit comments

Comments
 (0)