diff --git a/utl/flatbuffer/vw_to_flat.cc b/utl/flatbuffer/vw_to_flat.cc index 1245bed8796..11dba82eb59 100644 --- a/utl/flatbuffer/vw_to_flat.cc +++ b/utl/flatbuffer/vw_to_flat.cc @@ -363,7 +363,8 @@ void to_flat::convert_txt_to_flat(VW::workspace& all) MultiExampleBuilder multi_ex_builder; ExampleBuilder ex_builder; - VW::example* ae = all.example_parser->ready_parsed_examples.pop(); + VW::example* ae = nullptr; + all.example_parser->ready_parsed_examples.try_pop(ae); while (ae != nullptr && !ae->end_pass) { @@ -454,7 +455,8 @@ void to_flat::convert_txt_to_flat(VW::workspace& all) ex_builder.clear(); _multi_ex_index++; _examples++; - ae = all.example_parser->ready_parsed_examples.pop(); + ae = nullptr; + all.example_parser->ready_parsed_examples.try_pop(ae); continue; } else @@ -472,7 +474,8 @@ void to_flat::convert_txt_to_flat(VW::workspace& all) write_to_file(collection, all.l->is_multiline(), multi_ex_builder, ex_builder, outfile); - ae = all.example_parser->ready_parsed_examples.pop(); + ae = nullptr; + all.example_parser->ready_parsed_examples.try_pop(ae); } if (collection && _collection_count > 0) diff --git a/vowpalwabbit/core/include/vw/core/parser.h b/vowpalwabbit/core/include/vw/core/parser.h index 12affc4a495..7b96485349a 100644 --- a/vowpalwabbit/core/include/vw/core/parser.h +++ b/vowpalwabbit/core/include/vw/core/parser.h @@ -64,7 +64,7 @@ struct parser std::vector words; VW::object_pool example_pool; - VW::ptr_queue ready_parsed_examples; + VW::thread_safe_queue ready_parsed_examples; io_buf input; // Input source(s) diff --git a/vowpalwabbit/core/include/vw/core/queue.h b/vowpalwabbit/core/include/vw/core/queue.h index c1451ff1b38..e79a5f6859e 100644 --- a/vowpalwabbit/core/include/vw/core/queue.h +++ b/vowpalwabbit/core/include/vw/core/queue.h @@ -23,30 +23,30 @@ namespace VW { template -class ptr_queue +class thread_safe_queue { public: - ptr_queue(size_t max_size) : max_size(max_size) {} + thread_safe_queue(size_t max_size) : max_size(max_size) {} - T* pop() + bool try_pop(T& item) { std::unique_lock lock(mut); while (object_queue.size() == 0 && !done) { is_not_empty.wait(lock); } - if (done && object_queue.size() == 0) { return nullptr; } + if (done && object_queue.size() == 0) { return false; } - auto item = object_queue.front(); + item = std::move(object_queue.front()); object_queue.pop(); is_not_full.notify_all(); - return item; + return true; } - void push(T* item) + void push(T item) { std::unique_lock lock(mut); while (object_queue.size() == max_size) { is_not_full.wait(lock); } - object_queue.push(item); + object_queue.push(std::move(item)); is_not_empty.notify_all(); } @@ -69,7 +69,7 @@ class ptr_queue private: size_t max_size; - std::queue object_queue; + std::queue object_queue; mutable std::mutex mut; volatile bool done = false; diff --git a/vowpalwabbit/core/src/parser.cc b/vowpalwabbit/core/src/parser.cc index 68436aad69e..283a154ed23 100644 --- a/vowpalwabbit/core/src/parser.cc +++ b/vowpalwabbit/core/src/parser.cc @@ -913,14 +913,19 @@ void finish_example(VW::workspace& all, example& ec) void thread_dispatch(VW::workspace& all, const VW::multi_ex& examples) { - for (auto example : examples) { all.example_parser->ready_parsed_examples.push(example); } + for (auto* example : examples) { all.example_parser->ready_parsed_examples.push(example); } } void main_parse_loop(VW::workspace* all) { parse_dispatch(*all, thread_dispatch); } namespace VW { -example* get_example(parser* p) { return p->ready_parsed_examples.pop(); } +example* get_example(parser* p) +{ + example* ex = nullptr; + p->ready_parsed_examples.try_pop(ex); + return ex; +} float get_topic_prediction(example* ec, size_t i) { return ec->pred.scalars[i]; } @@ -980,9 +985,13 @@ void free_parser(VW::workspace& all) while (all.example_parser->ready_parsed_examples.size() > 0) { - auto* current = all.example_parser->ready_parsed_examples.pop(); - // this function also handles examples that were not from the pool. - VW::finish_example(all, *current); + VW::example* current = nullptr; + all.example_parser->ready_parsed_examples.try_pop(current); + if (current != nullptr) + { + // this function also handles examples that were not from the pool. + VW::finish_example(all, *current); + } } // There should be no examples in flight at this point.