diff --git a/cpp/src/arrow/csv/reader.cc b/cpp/src/arrow/csv/reader.cc index cf66b47e585..0f6f386029e 100644 --- a/cpp/src/arrow/csv/reader.cc +++ b/cpp/src/arrow/csv/reader.cc @@ -154,19 +154,12 @@ struct CSVBlock { template <> struct IterationTraits { static csv::CSVBlock End() { return csv::CSVBlock{{}, {}, {}, -1, true, {}}; } + static bool IsEnd(const csv::CSVBlock& val) { return val.block_index < 0; } }; namespace csv { namespace { -// The == operator must be defined to be used as T in Iterator -bool operator==(const CSVBlock& left, const CSVBlock& right) { - return left.block_index == right.block_index; -} -bool operator!=(const CSVBlock& left, const CSVBlock& right) { - return left.block_index != right.block_index; -} - // This is a callable that can be used to transform an iterator. The source iterator // will contain buffers of data and the output iterator will contain delimited CSV // blocks. util::optional is used so that there is an end token (required by the @@ -731,7 +724,7 @@ class SerialStreamingReader : public BaseStreamingReader { if (!source_eof_) { ARROW_ASSIGN_OR_RAISE(auto maybe_block, block_iterator_.Next()); - if (maybe_block != IterationTraits::End()) { + if (!IsIterationEnd(maybe_block)) { last_block_index_ = maybe_block.block_index; auto maybe_parsed = ParseAndInsert(maybe_block.partial, maybe_block.completion, maybe_block.buffer, maybe_block.block_index, @@ -813,7 +806,7 @@ class SerialTableReader : public BaseTableReader { RETURN_NOT_OK(stop_token_.Poll()); ARROW_ASSIGN_OR_RAISE(auto maybe_block, block_iterator.Next()); - if (maybe_block == IterationTraits::End()) { + if (IsIterationEnd(maybe_block)) { // EOF break; } @@ -865,7 +858,7 @@ class AsyncThreadedTableReader auto transferred_it = MakeTransferredGenerator(bg_it, cpu_executor_); - int32_t block_queue_size = std::max(2, cpu_executor_->GetCapacity()); + int32_t block_queue_size = cpu_executor_->GetCapacity(); auto rh_it = MakeSerialReadaheadGenerator(std::move(transferred_it), block_queue_size); buffer_generator_ = CSVBufferIterator::MakeAsync(std::move(rh_it)); diff --git a/cpp/src/arrow/testing/future_util.h b/cpp/src/arrow/testing/future_util.h index 3679c6b918d..44fa78c375c 100644 --- a/cpp/src/arrow/testing/future_util.h +++ b/cpp/src/arrow/testing/future_util.h @@ -26,7 +26,7 @@ // unit test anyways. #define ASSERT_FINISHES_IMPL(fut) \ do { \ - ASSERT_TRUE(fut.Wait(10)); \ + ASSERT_TRUE(fut.Wait(300)); \ if (!fut.is_finished()) { \ FAIL() << "Future did not finish in a timely fashion"; \ } \ @@ -35,11 +35,11 @@ #define ASSERT_FINISHES_OK(expr) \ do { \ auto&& _fut = (expr); \ - ASSERT_TRUE(_fut.Wait(10)); \ + ASSERT_TRUE(_fut.Wait(300)); \ if (!_fut.is_finished()) { \ FAIL() << "Future did not finish in a timely fashion"; \ } \ - auto _st = _fut.status(); \ + auto& _st = _fut.status(); \ if (!_st.ok()) { \ FAIL() << "'" ARROW_STRINGIFY(expr) "' failed with " << _st.ToString(); \ } \ diff --git a/cpp/src/arrow/testing/gtest_util.cc b/cpp/src/arrow/testing/gtest_util.cc index 462a5237921..0917149d014 100644 --- a/cpp/src/arrow/testing/gtest_util.cc +++ b/cpp/src/arrow/testing/gtest_util.cc @@ -49,8 +49,10 @@ #include "arrow/table.h" #include "arrow/type.h" #include "arrow/util/checked_cast.h" +#include "arrow/util/future.h" #include "arrow/util/io_util.h" #include "arrow/util/logging.h" +#include "arrow/util/windows_compatibility.h" namespace arrow { @@ -596,6 +598,33 @@ void SleepFor(double seconds) { std::chrono::nanoseconds(static_cast(seconds * 1e9))); } +#ifdef _WIN32 +void SleepABit() { + LARGE_INTEGER freq, start, now; + QueryPerformanceFrequency(&freq); + // 1 ms + auto desired = freq.QuadPart / 1000; + if (desired <= 0) { + // Fallback to STL sleep if high resolution clock not available, tests may fail, + // shouldn't really happen + SleepFor(1e-3); + return; + } + QueryPerformanceCounter(&start); + while (true) { + std::this_thread::yield(); + QueryPerformanceCounter(&now); + auto elapsed = now.QuadPart - start.QuadPart; + if (elapsed > desired) { + break; + } + } +} +#else +// std::this_thread::sleep_for should be high enough resolution on non-Windows systems +void SleepABit() { SleepFor(1e-3); } +#endif + void BusyWait(double seconds, std::function predicate) { const double period = 0.001; for (int i = 0; !predicate() && i * period < seconds; ++i) { @@ -603,6 +632,24 @@ void BusyWait(double seconds, std::function predicate) { } } +Future<> SleepAsync(double seconds) { + auto out = Future<>::Make(); + std::thread([out, seconds]() mutable { + SleepFor(seconds); + out.MarkFinished(Status::OK()); + }).detach(); + return out; +} + +Future<> SleepABitAsync() { + auto out = Future<>::Make(); + std::thread([out]() mutable { + SleepABit(); + out.MarkFinished(Status::OK()); + }).detach(); + return out; +} + /////////////////////////////////////////////////////////////////////////// // Extension types diff --git a/cpp/src/arrow/testing/gtest_util.h b/cpp/src/arrow/testing/gtest_util.h index 744af0e0f75..c3618a17151 100644 --- a/cpp/src/arrow/testing/gtest_util.h +++ b/cpp/src/arrow/testing/gtest_util.h @@ -434,10 +434,24 @@ inline void BitmapFromVector(const std::vector& is_valid, ARROW_TESTING_EXPORT void SleepFor(double seconds); +// Sleeps for a very small amount of time. The thread will be yielded +// at least once ensuring that context switches could happen. It is intended +// to be used for stress testing parallel code and shouldn't be assumed to do any +// reliable timing. +ARROW_TESTING_EXPORT +void SleepABit(); + // Wait until predicate is true or timeout in seconds expires. ARROW_TESTING_EXPORT void BusyWait(double seconds, std::function predicate); +ARROW_TESTING_EXPORT +Future<> SleepAsync(double seconds); + +// \see SleepABit +ARROW_TESTING_EXPORT +Future<> SleepABitAsync(); + template std::vector IteratorToVector(Iterator iterator) { EXPECT_OK_AND_ASSIGN(auto out, iterator.ToVector()); diff --git a/cpp/src/arrow/type_fwd.h b/cpp/src/arrow/type_fwd.h index 168e172bc88..46018ef13be 100644 --- a/cpp/src/arrow/type_fwd.h +++ b/cpp/src/arrow/type_fwd.h @@ -35,6 +35,12 @@ class Result; class Status; +namespace detail { +struct Empty; +} +template +class Future; + namespace util { class Codec; } // namespace util diff --git a/cpp/src/arrow/util/CMakeLists.txt b/cpp/src/arrow/util/CMakeLists.txt index 1f98bead0a4..37987b98520 100644 --- a/cpp/src/arrow/util/CMakeLists.txt +++ b/cpp/src/arrow/util/CMakeLists.txt @@ -41,6 +41,7 @@ endif() add_arrow_test(utility-test SOURCES align_util_test.cc + async_generator_test.cc bit_block_counter_test.cc bit_util_test.cc cache_test.cc @@ -60,6 +61,7 @@ add_arrow_test(utility-test stl_util_test.cc string_test.cc tdigest_test.cc + test_common.cc time_test.cc trie_test.cc uri_test.cc diff --git a/cpp/src/arrow/util/algorithm.h b/cpp/src/arrow/util/algorithm.h new file mode 100644 index 00000000000..2a0e6ba709d --- /dev/null +++ b/cpp/src/arrow/util/algorithm.h @@ -0,0 +1,33 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include "arrow/result.h" + +namespace arrow { + +template +Status MaybeTransform(InputIterator first, InputIterator last, OutputIterator out, + UnaryOperation unary_op) { + for (; first != last; ++first, (void)++out) { + ARROW_ASSIGN_OR_RAISE(*out, unary_op(*first)); + } + return Status::OK(); +} + +} // namespace arrow diff --git a/cpp/src/arrow/util/async_generator.h b/cpp/src/arrow/util/async_generator.h index 29285fbd25c..7cb73f4ed87 100644 --- a/cpp/src/arrow/util/async_generator.h +++ b/cpp/src/arrow/util/async_generator.h @@ -24,7 +24,6 @@ #include "arrow/util/functional.h" #include "arrow/util/future.h" #include "arrow/util/iterator.h" -#include "arrow/util/logging.h" #include "arrow/util/mutex.h" #include "arrow/util/optional.h" #include "arrow/util/queue.h" @@ -32,20 +31,47 @@ namespace arrow { +// The methods in this file create, modify, and utilize AsyncGenerator which is an +// iterator of futures. This allows an asynchronous source (like file input) to be run +// through a pipeline in the same way that iterators can be used to create pipelined +// workflows. +// +// In order to support pipeline parallelism we introduce the concept of asynchronous +// reentrancy. This is different than synchronous reentrancy. With synchronous code a +// function is reentrant if the function can be called again while a previous call to that +// function is still running. Unless otherwise specified none of these generators are +// synchronously reentrant. Care should be taken to avoid calling them in such a way (and +// the utilities Visit/Collect/Await take care to do this). +// +// Asynchronous reentrancy on the other hand means the function is called again before the +// future returned by the function is marekd finished (but after the call to get the +// future returns). Some of these generators are async-reentrant while others (e.g. +// those that depend on ordered processing like decompression) are not. Read the MakeXYZ +// function comments to determine which generators support async reentrancy. +// +// Note: Generators that are not asynchronously reentrant can still support readahead +// (\see MakeSerialReadaheadGenerator). +// +// Readahead operators, and some other operators, may introduce queueing. Any operators +// that introduce buffering should detail the amount of buffering they introduce in their +// MakeXYZ function comments. template using AsyncGenerator = std::function()>; template -Future AsyncGeneratorEnd() { - return Future::MakeFinished(IterationTraits::End()); -} +struct IterationTraits> { + /// \brief by default when iterating through a sequence of AsyncGenerator, + /// an empty function indicates the end of iteration. + static AsyncGenerator End() { return AsyncGenerator(); } + + static bool IsEnd(const AsyncGenerator& val) { return !val; } +}; template -bool IsGeneratorEnd(const T& value) { - return value == IterationTraits::End(); +Future AsyncGeneratorEnd() { + return Future::MakeFinished(IterationTraits::End()); } -/// Iterates through a generator of futures, visiting the result of each one and /// returning a future that completes when all have been visited template Future<> VisitAsyncGenerator(AsyncGenerator generator, @@ -53,7 +79,7 @@ Future<> VisitAsyncGenerator(AsyncGenerator generator, struct LoopBody { struct Callback { Result> operator()(const T& result) { - if (result == IterationTraits::End()) { + if (IsIterationEnd(result)) { return Break(detail::Empty()); } else { auto visited = visitor(result); @@ -81,6 +107,14 @@ Future<> VisitAsyncGenerator(AsyncGenerator generator, return Loop(LoopBody{std::move(generator), std::move(visitor)}); } +/// \brief Waits for an async generator to complete, discarding results. +template +Future<> DiscardAllFromAsyncGenerator(AsyncGenerator generator) { + std::function visitor = [](...) { return Status::OK(); }; + return VisitAsyncGenerator(generator, visitor); +} + +/// \brief Collects the results of an async generator into a vector template Future> CollectAsyncGenerator(AsyncGenerator generator) { auto vec = std::make_shared>(); @@ -89,7 +123,7 @@ Future> CollectAsyncGenerator(AsyncGenerator generator) { auto next = generator_(); auto vec = vec_; return next.Then([vec](const T& result) -> Result>> { - if (result == IterationTraits::End()) { + if (IsIterationEnd(result)) { return Break(*vec); } else { vec->push_back(result); @@ -103,6 +137,150 @@ Future> CollectAsyncGenerator(AsyncGenerator generator) { return Loop(LoopBody{std::move(generator), std::move(vec)}); } +/// \see MakeMappedGenerator +template +class MappingGenerator { + public: + MappingGenerator(AsyncGenerator source, std::function(const T&)> map) + : state_(std::make_shared(std::move(source), std::move(map))) {} + + Future operator()() { + auto future = Future::Make(); + bool should_trigger; + { + auto guard = state_->mutex.Lock(); + if (state_->finished) { + return AsyncGeneratorEnd(); + } + should_trigger = state_->waiting_jobs.empty(); + state_->waiting_jobs.push_back(future); + } + if (should_trigger) { + state_->source().AddCallback(Callback{state_}); + } + return future; + } + + private: + struct State { + State(AsyncGenerator source, std::function(const T&)> map) + : source(std::move(source)), + map(std::move(map)), + waiting_jobs(), + mutex(), + finished(false) {} + + void Purge() { + // This might be called by an original callback (if the source iterator fails or + // ends) or by a mapped callback (if the map function fails or ends prematurely). + // Either way it should only be called once and after finished is set so there is no + // need to guard access to `waiting_jobs`. + while (!waiting_jobs.empty()) { + waiting_jobs.front().MarkFinished(IterationTraits::End()); + waiting_jobs.pop_front(); + } + } + + AsyncGenerator source; + std::function(const T&)> map; + std::deque> waiting_jobs; + util::Mutex mutex; + bool finished; + }; + + struct Callback; + + struct MappedCallback { + void operator()(const Result& maybe_next) { + bool end = !maybe_next.ok() || IsIterationEnd(*maybe_next); + bool should_purge = false; + if (end) { + { + auto guard = state->mutex.Lock(); + should_purge = !state->finished; + state->finished = true; + } + } + sink.MarkFinished(maybe_next); + if (should_purge) { + state->Purge(); + } + } + std::shared_ptr state; + Future sink; + }; + + struct Callback { + void operator()(const Result& maybe_next) { + Future sink; + bool end = !maybe_next.ok() || IsIterationEnd(*maybe_next); + bool should_purge = false; + bool should_trigger; + { + auto guard = state->mutex.Lock(); + if (end) { + should_purge = !state->finished; + state->finished = true; + } + sink = state->waiting_jobs.front(); + state->waiting_jobs.pop_front(); + should_trigger = !end && !state->waiting_jobs.empty(); + } + if (should_purge) { + state->Purge(); + } + if (should_trigger) { + state->source().AddCallback(Callback{state}); + } + if (maybe_next.ok()) { + const T& val = maybe_next.ValueUnsafe(); + if (IsIterationEnd(val)) { + sink.MarkFinished(IterationTraits::End()); + } else { + Future mapped_fut = state->map(val); + mapped_fut.AddCallback(MappedCallback{std::move(state), std::move(sink)}); + } + } else { + sink.MarkFinished(maybe_next.status()); + } + } + + std::shared_ptr state; + }; + + std::shared_ptr state_; +}; + +/// \brief Creates a generator that will apply the map function to each element of +/// source. The map function is not called on the end token. +/// +/// Note: This function makes a copy of `map` for each item +/// Note: Errors returned from the `map` function will be propagated +/// +/// If the source generator is async-reentrant then this generator will be also +template +AsyncGenerator MakeMappedGenerator(AsyncGenerator source_generator, + std::function(const T&)> map) { + std::function(const T&)> future_map = [map](const T& val) -> Future { + return Future::MakeFinished(map(val)); + }; + return MappingGenerator(std::move(source_generator), std::move(future_map)); +} +template +AsyncGenerator MakeMappedGenerator(AsyncGenerator source_generator, + std::function map) { + std::function(const T&)> maybe_future_map = [map](const T& val) -> Future { + return Future::MakeFinished(map(val)); + }; + return MappingGenerator(std::move(source_generator), std::move(maybe_future_map)); +} +template +AsyncGenerator MakeMappedGenerator(AsyncGenerator source_generator, + std::function(const T&)> map) { + return MappingGenerator(std::move(source_generator), std::move(map)); +} + +/// \see MakeAsyncGenerator template class TransformingGenerator { // The transforming generator state will be referenced as an async generator but will @@ -128,8 +306,8 @@ class TransformingGenerator { } auto next_fut = generator_(); - // If finished already, process results immediately inside the loop to avoid stack - // overflow + // If finished already, process results immediately inside the loop to avoid + // stack overflow if (next_fut.is_finished()) { auto next_result = next_fut.result(); if (next_result.ok()) { @@ -157,7 +335,7 @@ class TransformingGenerator { if (!finished_ && last_value_.has_value()) { ARROW_ASSIGN_OR_RAISE(TransformFlow next, transformer_(*last_value_)); if (next.ReadyForNext()) { - if (*last_value_ == IterationTraits::End()) { + if (IsIterationEnd(*last_value_)) { finished_ = true; } last_value_.reset(); @@ -193,6 +371,23 @@ class TransformingGenerator { std::shared_ptr state_; }; +/// \brief Transforms an async generator using a transformer function returning a new +/// AsyncGenerator +/// +/// The transform function here behaves exactly the same as the transform function in +/// MakeTransformedIterator and you can safely use the same transform function to +/// transform both synchronous and asynchronous streams. +/// +/// This generator is not async-reentrant +/// +/// This generator may queue up to 1 instance of T +template +AsyncGenerator MakeAsyncGenerator(AsyncGenerator generator, + Transformer transformer) { + return TransformingGenerator(generator, transformer); +} + +/// \see MakeSerialReadaheadGenerator template class SerialReadaheadGenerator { public: @@ -233,8 +428,10 @@ class SerialReadaheadGenerator { : first_(true), source_(std::move(source)), finished_(false), - spaces_available_(max_readahead), - readahead_queue_(max_readahead) {} + // There is one extra "space" for the in-flight request + spaces_available_(max_readahead + 1), + // The SPSC queue has size-1 "usable" slots so we need to overallocate 1 + readahead_queue_(max_readahead + 1) {} Status Pump(const std::shared_ptr& self) { // Can't do readahead_queue.write(source().Then(Callback{self})) because then the @@ -277,7 +474,7 @@ class SerialReadaheadGenerator { return maybe_next; } const auto& next = *maybe_next; - if (next == IterationTraits::End()) { + if (IsIterationEnd(next)) { state_->finished_.store(true); return maybe_next; } @@ -294,6 +491,21 @@ class SerialReadaheadGenerator { std::shared_ptr state_; }; +/// \brief Creates a generator that will pull from the source into a queue. Unlike +/// MakeReadaheadGenerator this will not pull reentrantly from the source. +/// +/// The source generator does not need to be async-reentrant +/// +/// This generator is not async-reentrant (even if the source is) +/// +/// This generator may queue up to max_readahead additional instances of T +template +AsyncGenerator MakeSerialReadaheadGenerator(AsyncGenerator source_generator, + int max_readahead) { + return SerialReadaheadGenerator(std::move(source_generator), max_readahead); +} + +/// \see MakeReadaheadGenerator template class ReadaheadGenerator { public: @@ -304,8 +516,7 @@ class ReadaheadGenerator { if (!next_result.ok()) { finished->store(true); } else { - const auto& next = *next_result; - if (next == IterationTraits::End()) { + if (IsIterationEnd(*next_result)) { *finished = true; } } @@ -449,41 +660,227 @@ class PushGenerator { /// The source generator must be async-reentrant /// /// This generator itself is async-reentrant. +/// +/// This generator may queue up to max_readahead instances of T template AsyncGenerator MakeReadaheadGenerator(AsyncGenerator source_generator, int max_readahead) { return ReadaheadGenerator(std::move(source_generator), max_readahead); } -/// \brief Creates a generator that will pull from the source into a queue. Unlike -/// MakeReadaheadGenerator this will not pull reentrantly from the source. -/// -/// The source generator does not need to be async-reentrant +/// \brief Creates a generator that will yield finished futures from a vector /// -/// This generator is not async-reentrant (even if the source is) +/// This generator is async-reentrant template -AsyncGenerator MakeSerialReadaheadGenerator(AsyncGenerator source_generator, - int max_readahead) { - return SerialReadaheadGenerator(std::move(source_generator), max_readahead); +AsyncGenerator MakeVectorGenerator(std::vector vec) { + struct State { + explicit State(std::vector vec_) : vec(std::move(vec_)), vec_idx(0) {} + + std::vector vec; + std::atomic vec_idx; + }; + + auto state = std::make_shared(std::move(vec)); + return [state]() { + auto idx = state->vec_idx.fetch_add(1); + if (idx >= state->vec.size()) { + return AsyncGeneratorEnd(); + } + return Future::MakeFinished(state->vec[idx]); + }; } -/// \brief Transforms an async generator using a transformer function returning a new -/// AsyncGenerator +/// \see MakeMergedGenerator +template +class MergedGenerator { + public: + explicit MergedGenerator(AsyncGenerator> source, + int max_subscriptions) + : state_(std::make_shared(std::move(source), max_subscriptions)) {} + + Future operator()() { + Future waiting_future; + std::shared_ptr delivered_job; + { + auto guard = state_->mutex.Lock(); + if (!state_->delivered_jobs.empty()) { + delivered_job = std::move(state_->delivered_jobs.front()); + state_->delivered_jobs.pop_front(); + } else if (state_->finished) { + return IterationTraits::End(); + } else { + waiting_future = Future::Make(); + state_->waiting_jobs.push_back(std::make_shared>(waiting_future)); + } + } + if (delivered_job) { + // deliverer will be invalid if outer callback encounters an error and delivers a + // failed result + if (delivered_job->deliverer) { + delivered_job->deliverer().AddCallback( + InnerCallback{state_, delivered_job->index}); + } + return std::move(delivered_job->value); + } + if (state_->first) { + state_->first = false; + for (std::size_t i = 0; i < state_->active_subscriptions.size(); i++) { + state_->source().AddCallback(OuterCallback{state_, i}); + } + } + return waiting_future; + } + + private: + struct DeliveredJob { + explicit DeliveredJob(AsyncGenerator deliverer_, Result value_, + std::size_t index_) + : deliverer(deliverer_), value(std::move(value_)), index(index_) {} + + AsyncGenerator deliverer; + Result value; + std::size_t index; + }; + + struct State { + State(AsyncGenerator> source, int max_subscriptions) + : source(std::move(source)), + active_subscriptions(max_subscriptions), + delivered_jobs(), + waiting_jobs(), + mutex(), + first(true), + source_exhausted(false), + finished(false), + num_active_subscriptions(max_subscriptions) {} + + AsyncGenerator> source; + // active_subscriptions and delivered_jobs will be bounded by max_subscriptions + std::vector> active_subscriptions; + std::deque> delivered_jobs; + // waiting_jobs is unbounded, reentrant pulls (e.g. AddReadahead) will provide the + // backpressure + std::deque>> waiting_jobs; + util::Mutex mutex; + bool first; + bool source_exhausted; + bool finished; + int num_active_subscriptions; + }; + + struct InnerCallback { + void operator()(const Result& maybe_next) { + Future sink; + bool sub_finished = maybe_next.ok() && IsIterationEnd(*maybe_next); + { + auto guard = state->mutex.Lock(); + if (state->finished) { + // We've errored out so just ignore this result and don't keep pumping + return; + } + if (!sub_finished) { + if (state->waiting_jobs.empty()) { + state->delivered_jobs.push_back(std::make_shared( + state->active_subscriptions[index], maybe_next, index)); + } else { + sink = std::move(*state->waiting_jobs.front()); + state->waiting_jobs.pop_front(); + } + } + } + if (sub_finished) { + state->source().AddCallback(OuterCallback{state, index}); + } else if (sink.is_valid()) { + sink.MarkFinished(maybe_next); + if (maybe_next.ok()) { + state->active_subscriptions[index]().AddCallback(*this); + } + } + } + std::shared_ptr state; + std::size_t index; + }; + + struct OuterCallback { + void operator()(const Result>& maybe_next) { + bool should_purge = false; + bool should_continue = false; + Future error_sink; + { + auto guard = state->mutex.Lock(); + if (!maybe_next.ok() || IsIterationEnd(*maybe_next)) { + state->source_exhausted = true; + if (!maybe_next.ok() || --state->num_active_subscriptions == 0) { + state->finished = true; + should_purge = true; + } + if (!maybe_next.ok()) { + if (state->waiting_jobs.empty()) { + state->delivered_jobs.push_back(std::make_shared( + AsyncGenerator(), maybe_next.status(), index)); + } else { + error_sink = std::move(*state->waiting_jobs.front()); + state->waiting_jobs.pop_front(); + } + } + } else { + state->active_subscriptions[index] = *maybe_next; + should_continue = true; + } + } + if (error_sink.is_valid()) { + error_sink.MarkFinished(maybe_next.status()); + } + if (should_continue) { + (*maybe_next)().AddCallback(InnerCallback{state, index}); + } else if (should_purge) { + // At this point state->finished has been marked true so no one else + // will be interacting with waiting_jobs and we can iterate outside lock + while (!state->waiting_jobs.empty()) { + state->waiting_jobs.front()->MarkFinished(IterationTraits::End()); + state->waiting_jobs.pop_front(); + } + } + } + std::shared_ptr state; + std::size_t index; + }; + + std::shared_ptr state_; +}; + +/// \brief Creates a generator that takes in a stream of generators and pulls from up to +/// max_subscriptions at a time /// -/// The transform function here behaves exactly the same as the transform function in -/// MakeTransformedIterator and you can safely use the same transform function to -/// transform both synchronous and asynchronous streams. +/// Note: This may deliver items out of sequence. For example, items from the third +/// AsyncGenerator generated by the source may be emitted before some items from the first +/// AsyncGenerator generated by the source. /// -/// This generator is not async-reentrant -template -AsyncGenerator MakeAsyncGenerator(AsyncGenerator generator, - Transformer transformer) { - return TransformingGenerator(generator, transformer); +/// This generator will pull from source async-reentrantly unless max_subscriptions is 1 +/// This generator will not pull from the individual subscriptions reentrantly. Add +/// readahead to the individual subscriptions if that is desired. +/// This generator is async-reentrant +/// +/// This generator may queue up to max_subscriptions instances of T +template +AsyncGenerator MakeMergedGenerator(AsyncGenerator> source, + int max_subscriptions) { + return MergedGenerator(std::move(source), max_subscriptions); } -/// \brief Transfers execution of the generator onto the given executor +/// \brief Creates a generator that takes in a stream of generators and pulls from each +/// one in sequence. /// -/// This generator is async-reentrant if the source generator is async-reentrant +/// This generator is async-reentrant but will never pull from source reentrantly and +/// will never pull from any subscription reentrantly. +/// +/// This generator may queue 1 instance of T +template +AsyncGenerator MakeConcatenatedGenerator(AsyncGenerator> source) { + return MergedGenerator(std::move(source), 1); +} + +/// \see MakeTransferredGenerator template class TransferringGenerator { public: @@ -508,16 +905,45 @@ class TransferringGenerator { /// /// Keep in mind that continuations called on an already completed future will /// always be run synchronously and so no transfer will happen in that case. +/// +/// This generator is async reentrant if the source is +/// +/// This generator will not queue template AsyncGenerator MakeTransferredGenerator(AsyncGenerator source, internal::Executor* executor) { return TransferringGenerator(std::move(source), executor); } +/// \see MakeIteratorGenerator +template +class IteratorGenerator { + public: + explicit IteratorGenerator(Iterator it) : it_(std::move(it)) {} + + Future operator()() { return Future::MakeFinished(it_.Next()); } -/// \brief Async generator that iterates on an underlying iterator in a -/// separate executor. + private: + Iterator it_; +}; + +/// \brief Constructs a generator that yields futures from an iterator. /// -/// This generator is async-reentrant +/// Note: Do not use this if you can avoid it. This blocks in an async +/// context which is a bad idea. If you're converting sync-I/O to async +/// then use MakeBackgroundGenerator. Otherwise, convert the underlying +/// source to async. This function is only around until we can conver the +/// remaining table readers to async. Once all uses of this generator have +/// been removed it should be removed(ARROW-11909). +/// +/// This generator is not async-reentrant +/// +/// This generator will not queue +template +AsyncGenerator MakeIteratorGenerator(Iterator it) { + return IteratorGenerator(std::move(it)); +} + +/// \see MakeBackgroundGenerator template class BackgroundGenerator { public: @@ -552,7 +978,7 @@ class BackgroundGenerator { return IterationTraits::End(); } auto next = it_->Next(); - if (!next.ok() || *next == IterationTraits::End()) { + if (!next.ok() || IsIterationEnd(*next)) { *done_ = true; } return next; @@ -570,6 +996,10 @@ class BackgroundGenerator { /// \brief Creates an AsyncGenerator by iterating over an Iterator on a background /// thread +/// +/// This generator is async-reentrant +/// +/// This generator will not queue template static Result> MakeBackgroundGenerator( Iterator iterator, internal::Executor* io_executor) { @@ -578,8 +1008,7 @@ static Result> MakeBackgroundGenerator( return [background_iterator]() { return (*background_iterator)(); }; } -/// \brief Converts an AsyncGenerator to an Iterator by blocking until each future -/// is finished +/// \see MakeGeneratorIterator template class GeneratorIterator { public: @@ -591,11 +1020,19 @@ class GeneratorIterator { AsyncGenerator source_; }; +/// \brief Converts an AsyncGenerator to an Iterator by blocking until each future +/// is finished template Result> MakeGeneratorIterator(AsyncGenerator source) { return Iterator(GeneratorIterator(std::move(source))); } +/// \brief Adds readahead to an iterator using a background thread. +/// +/// Under the hood this is converting the iterator to a generator using +/// MakeBackgroundGenerator, adding readahead to the converted generator with +/// MakeReadaheadGenerator, and then converting back to an iterator using +/// MakeGeneratorIterator. template Result> MakeReadaheadIterator(Iterator it, int readahead_queue_size) { ARROW_ASSIGN_OR_RAISE(auto io_executor, internal::ThreadPool::Make(1)); diff --git a/cpp/src/arrow/util/async_generator_test.cc b/cpp/src/arrow/util/async_generator_test.cc new file mode 100644 index 00000000000..4eaec0a592d --- /dev/null +++ b/cpp/src/arrow/util/async_generator_test.cc @@ -0,0 +1,943 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include +#include +#include +#include +#include + +#include "arrow/testing/future_util.h" +#include "arrow/testing/gtest_util.h" +#include "arrow/type_fwd.h" +#include "arrow/util/async_generator.h" +#include "arrow/util/test_common.h" +#include "arrow/util/vector.h" + +namespace arrow { + +template +AsyncGenerator AsyncVectorIt(std::vector v) { + return MakeVectorGenerator(std::move(v)); +} + +template +AsyncGenerator FailsAt(AsyncGenerator src, int failing_index) { + auto index = std::make_shared>(0); + return [src, index, failing_index]() { + auto idx = index->fetch_add(1); + if (idx >= failing_index) { + return Future::MakeFinished(Status::Invalid("XYZ")); + } + return src(); + }; +} + +template +AsyncGenerator SlowdownABit(AsyncGenerator source) { + return MakeMappedGenerator(std::move(source), [](const T& res) -> Future { + return SleepABitAsync().Then( + [res](const Result& empty) { return res; }); + }); +} + +template +class TrackingGenerator { + public: + explicit TrackingGenerator(AsyncGenerator source) + : state_(std::make_shared(std::move(source))) {} + + Future operator()() { + state_->num_read++; + return state_->source(); + } + + int num_read() { return state_->num_read; } + + private: + struct State { + explicit State(AsyncGenerator source) : source(std::move(source)), num_read(0) {} + + AsyncGenerator source; + int num_read; + }; + + std::shared_ptr state_; +}; + +// Yields items with a small pause between each one from a background thread +std::function()> BackgroundAsyncVectorIt(std::vector v, + bool sleep = true) { + auto pool = internal::GetCpuThreadPool(); + auto iterator = VectorIt(v); + auto slow_iterator = MakeTransformedIterator( + std::move(iterator), [sleep](TestInt item) -> Result> { + if (sleep) { + SleepABit(); + } + return TransformYield(item); + }); + + EXPECT_OK_AND_ASSIGN(auto background, + MakeBackgroundGenerator(std::move(slow_iterator), + internal::GetCpuThreadPool())); + return MakeTransferredGenerator(background, pool); +} + +template +void AssertAsyncGeneratorMatch(std::vector expected, AsyncGenerator actual) { + auto vec_future = CollectAsyncGenerator(std::move(actual)); + EXPECT_OK_AND_ASSIGN(auto vec, vec_future.result()); + EXPECT_EQ(expected, vec); +} + +template +void AssertGeneratorExhausted(AsyncGenerator& gen) { + ASSERT_FINISHES_OK_AND_ASSIGN(auto next, gen()); + ASSERT_TRUE(IsIterationEnd(next)); +} + +// -------------------------------------------------------------------- +// Asynchronous iterator tests + +template +class ReentrantCheckerGuard; + +template +ReentrantCheckerGuard ExpectNotAccessedReentrantly(AsyncGenerator* generator); + +template +class ReentrantChecker { + public: + Future operator()() { + if (state_->generated_unfinished_future.load()) { + state_->valid.store(false); + } + state_->generated_unfinished_future.store(true); + auto result = state_->source(); + return result.Then(Callback{state_}); + } + + bool valid() { return state_->valid.load(); } + + private: + explicit ReentrantChecker(AsyncGenerator source) + : state_(std::make_shared(std::move(source))) {} + + friend ReentrantCheckerGuard ExpectNotAccessedReentrantly( + AsyncGenerator* generator); + + struct State { + explicit State(AsyncGenerator source_) + : source(std::move(source_)), generated_unfinished_future(false), valid(true) {} + + AsyncGenerator source; + std::atomic generated_unfinished_future; + std::atomic valid; + }; + struct Callback { + Future operator()(const Result& result) { + state_->generated_unfinished_future.store(false); + return result; + } + std::shared_ptr state_; + }; + + std::shared_ptr state_; +}; + +template +class ReentrantCheckerGuard { + public: + explicit ReentrantCheckerGuard(ReentrantChecker checker) : checker_(checker) {} + + ARROW_DISALLOW_COPY_AND_ASSIGN(ReentrantCheckerGuard); + ReentrantCheckerGuard(ReentrantCheckerGuard&& other) : checker_(other.checker_) { + if (other.owner_) { + other.owner_ = false; + owner_ = true; + } else { + owner_ = false; + } + } + ReentrantCheckerGuard& operator=(ReentrantCheckerGuard&& other) { + checker_ = other.checker_; + if (other.owner_) { + other.owner_ = false; + owner_ = true; + } else { + owner_ = false; + } + return *this; + } + + ~ReentrantCheckerGuard() { + if (owner_ && !checker_.valid()) { + ADD_FAILURE() << "A generator was accessed reentrantly when the test asserted it " + "should not be."; + } + } + + private: + ReentrantChecker checker_; + bool owner_ = true; +}; + +template +ReentrantCheckerGuard ExpectNotAccessedReentrantly(AsyncGenerator* generator) { + auto reentrant_checker = ReentrantChecker(*generator); + *generator = reentrant_checker; + return ReentrantCheckerGuard(reentrant_checker); +} + +TEST(TestAsyncUtil, Visit) { + auto generator = AsyncVectorIt({1, 2, 3}); + unsigned int sum = 0; + auto sum_future = VisitAsyncGenerator(generator, [&sum](TestInt item) { + sum += item.value; + return Status::OK(); + }); + ASSERT_TRUE(sum_future.is_finished()); + ASSERT_EQ(6, sum); +} + +TEST(TestAsyncUtil, Collect) { + std::vector expected = {1, 2, 3}; + auto generator = AsyncVectorIt(expected); + auto collected = CollectAsyncGenerator(generator); + ASSERT_FINISHES_OK_AND_ASSIGN(auto collected_val, collected); + ASSERT_EQ(expected, collected_val); +} + +TEST(TestAsyncUtil, Map) { + std::vector input = {1, 2, 3}; + auto generator = AsyncVectorIt(input); + std::function mapper = [](const TestInt& in) { + return std::to_string(in.value); + }; + auto mapped = MakeMappedGenerator(std::move(generator), mapper); + std::vector expected{"1", "2", "3"}; + AssertAsyncGeneratorMatch(expected, mapped); +} + +TEST(TestAsyncUtil, MapAsync) { + std::vector input = {1, 2, 3}; + auto generator = AsyncVectorIt(input); + std::function(const TestInt&)> mapper = [](const TestInt& in) { + return SleepAsync(1e-3).Then([in](const Result& empty) { + return TestStr(std::to_string(in.value)); + }); + }; + auto mapped = MakeMappedGenerator(std::move(generator), mapper); + std::vector expected{"1", "2", "3"}; + AssertAsyncGeneratorMatch(expected, mapped); +} + +TEST(TestAsyncUtil, MapReentrant) { + std::vector input = {1, 2}; + auto source = AsyncVectorIt(input); + TrackingGenerator tracker(std::move(source)); + source = MakeTransferredGenerator(AsyncGenerator(tracker), + internal::GetCpuThreadPool()); + + std::atomic map_tasks_running(0); + // Mapper blocks until can_proceed is marked finished, should start multiple map tasks + Future<> can_proceed = Future<>::Make(); + std::function(const TestInt&)> mapper = [&](const TestInt& in) { + map_tasks_running.fetch_add(1); + return can_proceed.Then([in](...) { return TestStr(std::to_string(in.value)); }); + }; + auto mapped = MakeMappedGenerator(std::move(source), mapper); + + EXPECT_EQ(0, tracker.num_read()); + + auto one = mapped(); + auto two = mapped(); + + BusyWait(10, [&] { return map_tasks_running.load() == 2; }); + EXPECT_EQ(2, map_tasks_running.load()); + EXPECT_EQ(2, tracker.num_read()); + + auto end_one = mapped(); + auto end_two = mapped(); + + can_proceed.MarkFinished(); + ASSERT_FINISHES_OK_AND_ASSIGN(auto oneval, one); + EXPECT_EQ("1", oneval.value); + ASSERT_FINISHES_OK_AND_ASSIGN(auto twoval, two); + EXPECT_EQ("2", twoval.value); + ASSERT_FINISHES_OK_AND_ASSIGN(auto end, end_one); + ASSERT_EQ(IterationTraits::End(), end); + ASSERT_FINISHES_OK_AND_ASSIGN(end, end_two); + ASSERT_EQ(IterationTraits::End(), end); +} + +TEST(TestAsyncUtil, MapParallelStress) { + constexpr int NTASKS = 10; + constexpr int NITEMS = 10; + for (int i = 0; i < NTASKS; i++) { + auto gen = MakeVectorGenerator(RangeVector(NITEMS)); + gen = SlowdownABit(std::move(gen)); + auto guard = ExpectNotAccessedReentrantly(&gen); + std::function mapper = [](const TestInt& in) { + SleepABit(); + return std::to_string(in.value); + }; + auto mapped = MakeMappedGenerator(std::move(gen), mapper); + mapped = MakeReadaheadGenerator(mapped, 8); + ASSERT_FINISHES_OK_AND_ASSIGN(auto collected, CollectAsyncGenerator(mapped)); + ASSERT_EQ(NITEMS, collected.size()); + } +} + +TEST(TestAsyncUtil, MapTaskFail) { + std::vector input = {1, 2, 3}; + auto generator = AsyncVectorIt(input); + std::function(const TestInt&)> mapper = + [](const TestInt& in) -> Result { + if (in.value == 2) { + return Status::Invalid("XYZ"); + } + return TestStr(std::to_string(in.value)); + }; + auto mapped = MakeMappedGenerator(std::move(generator), mapper); + ASSERT_FINISHES_AND_RAISES(Invalid, CollectAsyncGenerator(mapped)); +} + +TEST(TestAsyncUtil, MapSourceFail) { + std::vector input = {1, 2, 3}; + auto generator = FailsAt(AsyncVectorIt(input), 1); + std::function(const TestInt&)> mapper = + [](const TestInt& in) -> Result { + return TestStr(std::to_string(in.value)); + }; + auto mapped = MakeMappedGenerator(std::move(generator), mapper); + ASSERT_FINISHES_AND_RAISES(Invalid, CollectAsyncGenerator(mapped)); +} + +TEST(TestAsyncUtil, Concatenated) { + std::vector inputOne{1, 2, 3}; + std::vector inputTwo{4, 5, 6}; + std::vector expected{1, 2, 3, 4, 5, 6}; + auto gen = AsyncVectorIt>( + {AsyncVectorIt(inputOne), AsyncVectorIt(inputTwo)}); + auto concat = MakeConcatenatedGenerator(gen); + AssertAsyncGeneratorMatch(expected, concat); +} + +class GeneratorTestFixture : public ::testing::TestWithParam { + protected: + AsyncGenerator MakeSource(const std::vector& items) { + std::vector wrapped(items.begin(), items.end()); + auto gen = AsyncVectorIt(std::move(wrapped)); + bool slow = GetParam(); + if (slow) { + return SlowdownABit(std::move(gen)); + } + return gen; + } + + AsyncGenerator MakeFailingSource() { + AsyncGenerator gen = [] { + return Future::MakeFinished(Status::Invalid("XYZ")); + }; + bool slow = GetParam(); + if (slow) { + return SlowdownABit(std::move(gen)); + } + return gen; + } + + int GetNumItersForStress() { + bool slow = GetParam(); + // Run fewer trials for the slow case since they take longer + if (slow) { + return 10; + } else { + return 100; + } + } +}; + +TEST_P(GeneratorTestFixture, Merged) { + auto gen = AsyncVectorIt>( + {MakeSource({1, 2, 3}), MakeSource({4, 5, 6})}); + + auto concat_gen = MakeMergedGenerator(gen, 10); + ASSERT_FINISHES_OK_AND_ASSIGN(auto concat, CollectAsyncGenerator(concat_gen)); + auto concat_ints = + internal::MapVector([](const TestInt& val) { return val.value; }, concat); + std::set concat_set(concat_ints.begin(), concat_ints.end()); + + std::set expected{1, 2, 4, 3, 5, 6}; + ASSERT_EQ(expected, concat_set); +} + +TEST_P(GeneratorTestFixture, MergedInnerFail) { + auto gen = AsyncVectorIt>( + {MakeSource({1, 2, 3}), MakeFailingSource()}); + auto merged_gen = MakeMergedGenerator(gen, 10); + ASSERT_FINISHES_AND_RAISES(Invalid, CollectAsyncGenerator(merged_gen)); +} + +TEST_P(GeneratorTestFixture, MergedOuterFail) { + auto gen = + FailsAt(AsyncVectorIt>( + {MakeSource({1, 2, 3}), MakeSource({1, 2, 3}), MakeSource({1, 2, 3})}), + 1); + auto merged_gen = MakeMergedGenerator(gen, 10); + ASSERT_FINISHES_AND_RAISES(Invalid, CollectAsyncGenerator(merged_gen)); +} + +TEST_P(GeneratorTestFixture, MergedLimitedSubscriptions) { + auto gen = AsyncVectorIt>( + {MakeSource({1, 2}), MakeSource({3, 4}), MakeSource({5, 6, 7, 8}), + MakeSource({9, 10, 11, 12})}); + TrackingGenerator> tracker(std::move(gen)); + auto merged = MakeMergedGenerator(AsyncGenerator>(tracker), 2); + + SleepABit(); + // Lazy pull, should not start until first pull + ASSERT_EQ(0, tracker.num_read()); + + ASSERT_FINISHES_OK_AND_ASSIGN(auto next, merged()); + ASSERT_TRUE(next.value == 1 || next.value == 3); + + // First 2 values have to come from one of the first 2 sources + ASSERT_EQ(2, tracker.num_read()); + ASSERT_FINISHES_OK_AND_ASSIGN(next, merged()); + ASSERT_LT(next.value, 5); + ASSERT_GT(next.value, 0); + + // By the time five values have been read we should have exhausted at + // least one source + for (int i = 0; i < 3; i++) { + ASSERT_FINISHES_OK_AND_ASSIGN(next, merged()); + // 9 is possible if we read 1,2,3,4 and then grab 9 while 5 is running slow + ASSERT_LT(next.value, 10); + ASSERT_GT(next.value, 0); + } + ASSERT_GT(tracker.num_read(), 2); + ASSERT_LT(tracker.num_read(), 5); + + // Read remaining values + for (int i = 0; i < 7; i++) { + ASSERT_FINISHES_OK_AND_ASSIGN(next, merged()); + ASSERT_LT(next.value, 13); + ASSERT_GT(next.value, 0); + } + + AssertGeneratorExhausted(merged); +} + +TEST_P(GeneratorTestFixture, MergedStress) { + constexpr int NGENERATORS = 10; + constexpr int NITEMS = 10; + for (int i = 0; i < GetNumItersForStress(); i++) { + std::vector> sources; + std::vector> guards; + for (int j = 0; j < NGENERATORS; j++) { + auto source = MakeSource(RangeVector(NITEMS)); + guards.push_back(ExpectNotAccessedReentrantly(&source)); + sources.push_back(source); + } + AsyncGenerator> source_gen = AsyncVectorIt(sources); + + auto merged = MakeMergedGenerator(source_gen, 4); + ASSERT_FINISHES_OK_AND_ASSIGN(auto items, CollectAsyncGenerator(merged)); + ASSERT_EQ(NITEMS * NGENERATORS, items.size()); + } +} + +TEST_P(GeneratorTestFixture, MergedParallelStress) { + constexpr int NGENERATORS = 10; + constexpr int NITEMS = 10; + for (int i = 0; i < GetNumItersForStress(); i++) { + std::vector> sources; + for (int j = 0; j < NGENERATORS; j++) { + sources.push_back(MakeSource(RangeVector(NITEMS))); + } + auto merged = MakeMergedGenerator(AsyncVectorIt(sources), 4); + merged = MakeReadaheadGenerator(merged, 4); + ASSERT_FINISHES_OK_AND_ASSIGN(auto items, CollectAsyncGenerator(merged)); + ASSERT_EQ(NITEMS * NGENERATORS, items.size()); + } +} + +INSTANTIATE_TEST_SUITE_P(GeneratorTests, GeneratorTestFixture, + ::testing::Values(false, true)); + +TEST(TestAsyncUtil, FromVector) { + AsyncGenerator gen; + { + std::vector input = {1, 2, 3}; + gen = MakeVectorGenerator(std::move(input)); + } + std::vector expected = {1, 2, 3}; + AssertAsyncGeneratorMatch(expected, gen); +} + +TEST(TestAsyncUtil, SynchronousFinish) { + AsyncGenerator generator = []() { + return Future::MakeFinished(IterationTraits::End()); + }; + Transformer skip_all = [](TestInt value) { return TransformSkip(); }; + auto transformed = MakeAsyncGenerator(generator, skip_all); + auto future = CollectAsyncGenerator(transformed); + ASSERT_FINISHES_OK_AND_ASSIGN(auto actual, future); + ASSERT_EQ(std::vector(), actual); +} + +TEST(TestAsyncUtil, GeneratorIterator) { + auto generator = BackgroundAsyncVectorIt({1, 2, 3}); + ASSERT_OK_AND_ASSIGN(auto iterator, MakeGeneratorIterator(std::move(generator))); + ASSERT_OK_AND_EQ(TestInt(1), iterator.Next()); + ASSERT_OK_AND_EQ(TestInt(2), iterator.Next()); + ASSERT_OK_AND_EQ(TestInt(3), iterator.Next()); + AssertIteratorExhausted(iterator); + AssertIteratorExhausted(iterator); +} + +TEST(TestAsyncUtil, MakeTransferredGenerator) { + std::mutex mutex; + std::condition_variable cv; + std::atomic finished(false); + + ASSERT_OK_AND_ASSIGN(auto thread_pool, internal::ThreadPool::Make(1)); + + // Needs to be a slow source to ensure we don't call Then on a completed + AsyncGenerator slow_generator = [&]() { + return thread_pool + ->Submit([&] { + std::unique_lock lock(mutex); + cv.wait_for(lock, std::chrono::duration(30), + [&] { return finished.load(); }); + return IterationTraits::End(); + }) + .ValueOrDie(); + }; + + auto transferred = + MakeTransferredGenerator(std::move(slow_generator), thread_pool.get()); + + auto current_thread_id = std::this_thread::get_id(); + auto fut = transferred().Then([¤t_thread_id](const Result& result) { + ASSERT_NE(current_thread_id, std::this_thread::get_id()); + }); + + { + std::lock_guard lg(mutex); + finished.store(true); + } + cv.notify_one(); + ASSERT_FINISHES_OK(fut); +} + +// This test is too slow for valgrind +#if !(defined(ARROW_VALGRIND) || defined(ADDRESS_SANITIZER)) + +TEST(TestAsyncUtil, StackOverflow) { + int counter = 0; + AsyncGenerator generator = [&counter]() { + if (counter < 10000) { + return Future::MakeFinished(counter++); + } else { + return Future::MakeFinished(IterationTraits::End()); + } + }; + Transformer discard = + [](TestInt next) -> Result> { return TransformSkip(); }; + auto transformed = MakeAsyncGenerator(generator, discard); + auto collected_future = CollectAsyncGenerator(transformed); + ASSERT_FINISHES_OK_AND_ASSIGN(auto collected, collected_future); + ASSERT_EQ(0, collected.size()); +} + +#endif + +TEST(TestAsyncUtil, Background) { + std::vector expected = {1, 2, 3}; + auto background = BackgroundAsyncVectorIt(expected); + auto future = CollectAsyncGenerator(background); + ASSERT_FINISHES_OK_AND_ASSIGN(auto collected, future); + ASSERT_EQ(expected, collected); +} + +struct SlowEmptyIterator { + Result Next() { + if (called_) { + return Status::Invalid("Should not have been called twice"); + } + SleepFor(0.1); + return IterationTraits::End(); + } + + private: + bool called_ = false; +}; + +TEST(TestAsyncUtil, BackgroundRepeatEnd) { + // Ensure that the background generator properly fulfills the asyncgenerator contract + // and can be called after it ends. + ASSERT_OK_AND_ASSIGN(auto io_pool, internal::ThreadPool::Make(1)); + + auto iterator = Iterator(SlowEmptyIterator()); + ASSERT_OK_AND_ASSIGN(auto background_gen, + MakeBackgroundGenerator(std::move(iterator), io_pool.get())); + + background_gen = + MakeTransferredGenerator(std::move(background_gen), internal::GetCpuThreadPool()); + + auto one = background_gen(); + auto two = background_gen(); + + ASSERT_FINISHES_OK_AND_ASSIGN(auto one_fin, one); + ASSERT_TRUE(IsIterationEnd(one_fin)); + + ASSERT_FINISHES_OK_AND_ASSIGN(auto two_fin, two); + ASSERT_TRUE(IsIterationEnd(two_fin)); +} + +TEST(TestAsyncUtil, CompleteBackgroundStressTest) { + auto expected = RangeVector(20); + std::vector>> futures; + for (unsigned int i = 0; i < 20; i++) { + auto background = BackgroundAsyncVectorIt(expected); + futures.push_back(CollectAsyncGenerator(background)); + } + auto combined = All(futures); + ASSERT_FINISHES_OK_AND_ASSIGN(auto completed_vectors, combined); + for (std::size_t i = 0; i < completed_vectors.size(); i++) { + ASSERT_OK_AND_ASSIGN(auto vector, completed_vectors[i]); + ASSERT_EQ(vector, expected); + } +} + +TEST(TestAsyncUtil, SerialReadaheadSlowProducer) { + AsyncGenerator gen = BackgroundAsyncVectorIt({1, 2, 3, 4, 5}); + auto guard = ExpectNotAccessedReentrantly(&gen); + SerialReadaheadGenerator serial_readahead(gen, 2); + AssertAsyncGeneratorMatch({1, 2, 3, 4, 5}, + static_cast>(serial_readahead)); +} + +TEST(TestAsyncUtil, SerialReadaheadSlowConsumer) { + int num_delivered = 0; + auto source = [&num_delivered]() { + if (num_delivered < 5) { + return Future::MakeFinished(num_delivered++); + } else { + return Future::MakeFinished(IterationTraits::End()); + } + }; + AsyncGenerator serial_readahead = SerialReadaheadGenerator(source, 3); + SleepABit(); + ASSERT_EQ(0, num_delivered); + ASSERT_FINISHES_OK_AND_ASSIGN(auto next, serial_readahead()); + ASSERT_EQ(0, next.value); + ASSERT_EQ(4, num_delivered); + AssertAsyncGeneratorMatch({1, 2, 3, 4}, serial_readahead); + + // Ensure still reads ahead with just 1 slot + num_delivered = 0; + serial_readahead = SerialReadaheadGenerator(source, 1); + ASSERT_FINISHES_OK_AND_ASSIGN(next, serial_readahead()); + ASSERT_EQ(0, next.value); + ASSERT_EQ(2, num_delivered); + AssertAsyncGeneratorMatch({1, 2, 3, 4}, serial_readahead); +} + +TEST(TestAsyncUtil, SerialReadaheadStress) { + constexpr int NTASKS = 20; + constexpr int NITEMS = 50; + for (int i = 0; i < NTASKS; i++) { + AsyncGenerator gen = BackgroundAsyncVectorIt(RangeVector(NITEMS)); + auto guard = ExpectNotAccessedReentrantly(&gen); + SerialReadaheadGenerator serial_readahead(gen, 2); + auto visit_fut = + VisitAsyncGenerator(serial_readahead, [](TestInt test_int) -> Status { + // Normally sleeping in a visit function would be a faux-pas but we want to slow + // the reader down to match the producer to maximize the stress + SleepABit(); + return Status::OK(); + }); + ASSERT_FINISHES_OK(visit_fut); + } +} + +TEST(TestAsyncUtil, SerialReadaheadStressFast) { + constexpr int NTASKS = 20; + constexpr int NITEMS = 50; + for (int i = 0; i < NTASKS; i++) { + AsyncGenerator gen = BackgroundAsyncVectorIt(RangeVector(NITEMS), false); + auto guard = ExpectNotAccessedReentrantly(&gen); + SerialReadaheadGenerator serial_readahead(gen, 2); + auto visit_fut = VisitAsyncGenerator( + serial_readahead, [](TestInt test_int) -> Status { return Status::OK(); }); + ASSERT_FINISHES_OK(visit_fut); + } +} + +TEST(TestAsyncUtil, SerialReadaheadStressFailing) { + constexpr int NTASKS = 20; + constexpr int NITEMS = 50; + constexpr int EXPECTED_SUM = 45; + for (int i = 0; i < NTASKS; i++) { + AsyncGenerator it = BackgroundAsyncVectorIt(RangeVector(NITEMS)); + AsyncGenerator fails_at_ten = [&it]() { + auto next = it(); + return next.Then([](const Result& item) -> Result { + if (item->value >= 10) { + return Status::Invalid("XYZ"); + } else { + return item; + } + }); + }; + SerialReadaheadGenerator serial_readahead(fails_at_ten, 2); + unsigned int sum = 0; + auto visit_fut = VisitAsyncGenerator(serial_readahead, + [&sum](TestInt test_int) -> Status { + sum += test_int.value; + // Sleep to maximize stress + SleepABit(); + return Status::OK(); + }); + ASSERT_FINISHES_AND_RAISES(Invalid, visit_fut); + ASSERT_EQ(EXPECTED_SUM, sum); + } +} + +TEST(TestAsyncUtil, Readahead) { + int num_delivered = 0; + auto source = [&num_delivered]() { + if (num_delivered < 5) { + return Future::MakeFinished(num_delivered++); + } else { + return Future::MakeFinished(IterationTraits::End()); + } + }; + auto readahead = MakeReadaheadGenerator(source, 10); + // Should not pump until first item requested + ASSERT_EQ(0, num_delivered); + + auto first = readahead(); + // At this point the pumping should have happened + ASSERT_EQ(5, num_delivered); + ASSERT_FINISHES_OK_AND_ASSIGN(auto first_val, first); + ASSERT_EQ(TestInt(0), first_val); + + // Read the rest + for (int i = 0; i < 4; i++) { + auto next = readahead(); + ASSERT_FINISHES_OK_AND_ASSIGN(auto next_val, next); + ASSERT_EQ(TestInt(i + 1), next_val); + } + + // Next should be end + auto last = readahead(); + ASSERT_FINISHES_OK_AND_ASSIGN(auto last_val, last); + ASSERT_TRUE(IsIterationEnd(last_val)); +} + +TEST(TestAsyncUtil, ReadaheadFailed) { + ASSERT_OK_AND_ASSIGN(auto thread_pool, internal::ThreadPool::Make(4)); + std::atomic counter(0); + // All tasks are a little slow. The first task fails. + // The readahead will have spawned 9 more tasks and they + // should all pass + auto source = [thread_pool, &counter]() -> Future { + auto count = counter++; + return *thread_pool->Submit([count]() -> Result { + if (count == 0) { + return Status::Invalid("X"); + } + return TestInt(count); + }); + }; + auto readahead = MakeReadaheadGenerator(source, 10); + ASSERT_FINISHES_AND_RAISES(Invalid, readahead()); + SleepABit(); + + for (int i = 0; i < 9; i++) { + ASSERT_FINISHES_OK_AND_ASSIGN(auto next_val, readahead()); + ASSERT_EQ(TestInt(i + 1), next_val); + } + ASSERT_FINISHES_OK_AND_ASSIGN(auto after, readahead()); + + // It's possible that finished was set quickly and there + // are only 10 elements + if (IsIterationEnd(after)) { + return; + } + + // It's also possible that finished was too slow and there + // ended up being 11 elements + ASSERT_EQ(TestInt(10), after); + // There can't be 12 elements because SleepABit will prevent it + ASSERT_FINISHES_OK_AND_ASSIGN(auto definitely_last, readahead()); + ASSERT_TRUE(IsIterationEnd(definitely_last)); +} + +TEST(TestAsyncIteratorTransform, SkipSome) { + auto original = AsyncVectorIt({1, 2, 3}); + auto filter = MakeFilter([](TestInt& t) { return t.value != 2; }); + auto filtered = MakeAsyncGenerator(std::move(original), filter); + AssertAsyncGeneratorMatch({"1", "3"}, std::move(filtered)); +} + +TEST(PushGenerator, Empty) { + PushGenerator gen; + auto producer = gen.producer(); + + auto fut = gen(); + AssertNotFinished(fut); + producer.Close(); + ASSERT_FINISHES_OK_AND_EQ(IterationTraits::End(), fut); + ASSERT_FINISHES_OK_AND_EQ(IterationTraits::End(), gen()); + + // Close idempotent + fut = gen(); + producer.Close(); + ASSERT_FINISHES_OK_AND_EQ(IterationTraits::End(), fut); + ASSERT_FINISHES_OK_AND_EQ(IterationTraits::End(), gen()); + ASSERT_FINISHES_OK_AND_EQ(IterationTraits::End(), gen()); +} + +TEST(PushGenerator, Success) { + PushGenerator gen; + auto producer = gen.producer(); + std::vector> futures; + + producer.Push(TestInt{1}); + producer.Push(TestInt{2}); + for (int i = 0; i < 3; ++i) { + futures.push_back(gen()); + } + ASSERT_FINISHES_OK_AND_EQ(TestInt{1}, futures[0]); + ASSERT_FINISHES_OK_AND_EQ(TestInt{2}, futures[1]); + AssertNotFinished(futures[2]); + + producer.Push(TestInt{3}); + ASSERT_FINISHES_OK_AND_EQ(TestInt{3}, futures[2]); + producer.Push(TestInt{4}); + futures.push_back(gen()); + ASSERT_FINISHES_OK_AND_EQ(TestInt{4}, futures[3]); + producer.Push(TestInt{5}); + producer.Close(); + for (int i = 0; i < 4; ++i) { + futures.push_back(gen()); + } + ASSERT_FINISHES_OK_AND_EQ(TestInt{5}, futures[4]); + for (int i = 5; i < 8; ++i) { + ASSERT_FINISHES_OK_AND_EQ(IterationTraits::End(), futures[i]); + } + ASSERT_FINISHES_OK_AND_EQ(IterationTraits::End(), gen()); +} + +TEST(PushGenerator, Errors) { + PushGenerator gen; + auto producer = gen.producer(); + std::vector> futures; + + producer.Push(TestInt{1}); + producer.Push(Status::Invalid("2")); + for (int i = 0; i < 3; ++i) { + futures.push_back(gen()); + } + ASSERT_FINISHES_OK_AND_EQ(TestInt{1}, futures[0]); + ASSERT_FINISHES_AND_RAISES(Invalid, futures[1]); + AssertNotFinished(futures[2]); + + producer.Push(Status::IOError("3")); + producer.Push(TestInt{4}); + ASSERT_FINISHES_AND_RAISES(IOError, futures[2]); + futures.push_back(gen()); + ASSERT_FINISHES_OK_AND_EQ(TestInt{4}, futures[3]); + producer.Close(); + ASSERT_FINISHES_OK_AND_EQ(IterationTraits::End(), gen()); +} + +TEST(PushGenerator, CloseEarly) { + PushGenerator gen; + auto producer = gen.producer(); + std::vector> futures; + + producer.Push(TestInt{1}); + producer.Push(TestInt{2}); + for (int i = 0; i < 3; ++i) { + futures.push_back(gen()); + } + producer.Close(); + producer.Push(TestInt{3}); + + ASSERT_FINISHES_OK_AND_EQ(TestInt{1}, futures[0]); + ASSERT_FINISHES_OK_AND_EQ(TestInt{2}, futures[1]); + ASSERT_FINISHES_OK_AND_EQ(IterationTraits::End(), futures[2]); + ASSERT_FINISHES_OK_AND_EQ(IterationTraits::End(), gen()); +} + +TEST(PushGenerator, Stress) { + const int NTHREADS = 20; + const int NVALUES = 2000; + const int NFUTURES = NVALUES + 100; + + PushGenerator gen; + auto producer = gen.producer(); + + std::atomic next_value{0}; + + auto producer_worker = [&]() { + while (true) { + int v = next_value.fetch_add(1); + if (v >= NVALUES) { + break; + } + producer.Push(v); + } + }; + + auto producer_main = [&]() { + std::vector threads; + for (int i = 0; i < NTHREADS; ++i) { + threads.emplace_back(producer_worker); + } + for (auto& thread : threads) { + thread.join(); + } + producer.Close(); + }; + + std::vector> results; + std::thread thread(producer_main); + for (int i = 0; i < NFUTURES; ++i) { + results.push_back(gen().result()); + } + thread.join(); + + std::unordered_set seen_values; + for (int i = 0; i < NVALUES; ++i) { + ASSERT_OK_AND_ASSIGN(auto v, results[i]); + ASSERT_EQ(seen_values.count(v.value), 0); + seen_values.insert(v.value); + } + for (int i = NVALUES; i < NFUTURES; ++i) { + ASSERT_OK_AND_EQ(IterationTraits::End(), results[i]); + } +} + +} // namespace arrow diff --git a/cpp/src/arrow/util/future.h b/cpp/src/arrow/util/future.h index 4ede2912e6d..376a2c3ca85 100644 --- a/cpp/src/arrow/util/future.h +++ b/cpp/src/arrow/util/future.h @@ -27,6 +27,7 @@ #include "arrow/result.h" #include "arrow/status.h" +#include "arrow/type_fwd.h" #include "arrow/util/functional.h" #include "arrow/util/macros.h" #include "arrow/util/optional.h" diff --git a/cpp/src/arrow/util/iterator.h b/cpp/src/arrow/util/iterator.h index 771b209a406..568cb1f5cd1 100644 --- a/cpp/src/arrow/util/iterator.h +++ b/cpp/src/arrow/util/iterator.h @@ -44,9 +44,23 @@ struct IterationTraits { /// \brief a reserved value which indicates the end of iteration. By /// default this is NULLPTR since most iterators yield pointer types. /// Specialize IterationTraits if different end semantics are required. + /// + /// Note: This should not be used to determine if a given value is a + /// terminal value. Use IsIterationEnd (which uses IsEnd) instead. This + /// is only for returning terminal values. static T End() { return T(NULLPTR); } + + /// \brief Checks to see if the value is a terminal value. + /// A method is used here since T is not neccesarily comparable in many + /// cases even though it has a distinct final value + static bool IsEnd(const T& val) { return val == End(); } }; +template +bool IsIterationEnd(const T& val) { + return IterationTraits::IsEnd(val); +} + template struct IterationTraits> { /// \brief by default when iterating through a sequence of optional, @@ -54,6 +68,11 @@ struct IterationTraits> { /// Specialize IterationTraits if different end semantics are required. static util::optional End() { return util::nullopt; } + /// \brief by default when iterating through a sequence of optional, + /// nullopt (!has_value()) indicates the end of iteration. + /// Specialize IterationTraits if different end semantics are required. + static bool IsEnd(const util::optional& val) { return !val.has_value(); } + // TODO(bkietz) The range-for loop over Iterator> yields // Result> which is unnecessary (since only the unyielded end optional // is nullopt. Add IterationTraits::GetRangeElement() to handle this case @@ -90,12 +109,10 @@ class Iterator : public util::EqualityComparable> { /// returned by the visitor, terminating iteration. template Status Visit(Visitor&& visitor) { - const auto end = IterationTraits::End(); - for (;;) { ARROW_ASSIGN_OR_RAISE(auto value, Next()); - if (value == end) break; + if (IsIterationEnd(value)) break; ARROW_RETURN_NOT_OK(visitor(std::move(value))); } @@ -266,7 +283,7 @@ class TransformIterator { } auto next = *next_res; if (next.ReadyForNext()) { - if (*last_value_ == IterationTraits::End()) { + if (IsIterationEnd(*last_value_)) { finished_ = true; } last_value_.reset(); @@ -314,6 +331,7 @@ struct IterationTraits> { // The end condition for an Iterator of Iterators is a default constructed (null) // Iterator. static Iterator End() { return Iterator(); } + static bool IsEnd(const Iterator& val) { return !val; } }; template @@ -405,7 +423,7 @@ class MapIterator { Result Next() { ARROW_ASSIGN_OR_RAISE(I i, it_.Next()); - if (i == IterationTraits::End()) { + if (IsIterationEnd(i)) { return IterationTraits::End(); } @@ -467,7 +485,7 @@ struct FilterIterator { for (;;) { ARROW_ASSIGN_OR_RAISE(From i, it_.Next()); - if (i == IterationTraits::End()) { + if (IsIterationEnd(i)) { return IterationTraits::End(); } @@ -503,12 +521,12 @@ class FlattenIterator { explicit FlattenIterator(Iterator> it) : parent_(std::move(it)) {} Result Next() { - if (child_ == IterationTraits>::End()) { + if (IsIterationEnd(child_)) { // Pop from parent's iterator. ARROW_ASSIGN_OR_RAISE(child_, parent_.Next()); // Check if final iteration reached. - if (child_ == IterationTraits>::End()) { + if (IsIterationEnd(child_)) { return IterationTraits::End(); } @@ -517,7 +535,7 @@ class FlattenIterator { // Pop from child_ and check for depletion. ARROW_ASSIGN_OR_RAISE(T out, child_.Next()); - if (out == IterationTraits::End()) { + if (IsIterationEnd(out)) { // Reset state such that we pop from parent on the recursive call child_ = IterationTraits>::End(); diff --git a/cpp/src/arrow/util/iterator_test.cc b/cpp/src/arrow/util/iterator_test.cc index 0cd8767bf87..60b57dea1e2 100644 --- a/cpp/src/arrow/util/iterator_test.cc +++ b/cpp/src/arrow/util/iterator_test.cc @@ -26,57 +26,13 @@ #include #include -#include "arrow/testing/future_util.h" #include "arrow/testing/gtest_util.h" #include "arrow/util/async_generator.h" #include "arrow/util/iterator.h" - +#include "arrow/util/test_common.h" +#include "arrow/util/vector.h" namespace arrow { -struct TestInt { - TestInt() : value(-999) {} - TestInt(int i) : value(i) {} // NOLINT runtime/explicit - int value; - - bool operator==(const TestInt& other) const { return value == other.value; } - - friend std::ostream& operator<<(std::ostream& os, const TestInt& v) { - os << "{" << v.value << "}"; - return os; - } -}; - -template <> -struct IterationTraits { - static TestInt End() { return TestInt(); } -}; - -struct TestStr { - TestStr() : value("") {} - TestStr(const std::string& s) : value(s) {} // NOLINT runtime/explicit - TestStr(const char* s) : value(s) {} // NOLINT runtime/explicit - explicit TestStr(const TestInt& test_int) { - if (test_int == IterationTraits::End()) { - value = ""; - } else { - value = std::to_string(test_int.value); - } - } - std::string value; - - bool operator==(const TestStr& other) const { return value == other.value; } - - friend std::ostream& operator<<(std::ostream& os, const TestStr& v) { - os << "{\"" << v.value << "\"}"; - return os; - } -}; - -template <> -struct IterationTraits { - static TestStr End() { return TestStr(); } -}; - template class TracingIterator { public: @@ -157,54 +113,12 @@ template inline Iterator EmptyIt() { return MakeEmptyIterator(); } + +// Non-templated version of VectorIt to allow better type deduction inline Iterator VectorIt(std::vector v) { return MakeVectorIterator(std::move(v)); } -AsyncGenerator AsyncVectorIt(std::vector v) { - size_t index = 0; - return [index, v]() mutable -> Future { - if (index >= v.size()) { - return Future::MakeFinished(IterationTraits::End()); - } - return Future::MakeFinished(v[index++]); - }; -} - -constexpr auto kYieldDuration = std::chrono::microseconds(50); - -// Yields items with a small pause between each one from a background thread -std::function()> BackgroundAsyncVectorIt(std::vector v, - bool sleep = true) { - auto pool = internal::GetCpuThreadPool(); - auto iterator = VectorIt(v); - auto slow_iterator = MakeTransformedIterator( - std::move(iterator), [sleep](TestInt item) -> Result> { - if (sleep) { - std::this_thread::sleep_for(kYieldDuration); - } - return TransformYield(item); - }); - - EXPECT_OK_AND_ASSIGN(auto background, - MakeBackgroundGenerator(std::move(slow_iterator), - internal::GetCpuThreadPool())); - return MakeTransferredGenerator(background, pool); -} - -std::vector RangeVector(unsigned int max) { - std::vector range(max); - for (unsigned int i = 0; i < max; i++) { - range[i] = i; - } - return range; -} - -template -inline Iterator VectorIt(std::vector v) { - return MakeVectorIterator(std::move(v)); -} - template inline Iterator FilterIt(Iterator it, Fn&& fn) { return MakeFilterIterator(std::forward(fn), std::move(it)); @@ -220,13 +134,6 @@ void AssertIteratorMatch(std::vector expected, Iterator actual) { EXPECT_EQ(expected, IteratorToVector(std::move(actual))); } -template -void AssertAsyncGeneratorMatch(std::vector expected, AsyncGenerator actual) { - auto vec_future = CollectAsyncGenerator(std::move(actual)); - EXPECT_OK_AND_ASSIGN(auto vec, vec_future.result()); - EXPECT_EQ(expected, vec); -} - template void AssertIteratorNoMatch(std::vector expected, Iterator actual) { EXPECT_NE(expected, IteratorToVector(std::move(actual))); @@ -238,11 +145,6 @@ void AssertIteratorNext(T expected, Iterator& it) { ASSERT_EQ(expected, actual); } -template -void AssertIteratorExhausted(Iterator& it) { - AssertIteratorNext(IterationTraits::End(), it); -} - // -------------------------------------------------------------------- // Synchronous iterator tests @@ -336,16 +238,6 @@ TEST(TestIteratorTransform, TruncatingShort) { AssertIteratorMatch({"1"}, std::move(truncated)); } -Transformer MakeFilter(std::function filter) { - return [filter](TestInt next) -> Result> { - if (filter(next)) { - return TransformYield(TestStr(next)); - } else { - return TransformSkip(); - } - }; -} - TEST(TestIteratorTransform, SkipSome) { // Exercises TransformSkip auto original = VectorIt({1, 2, 3}); @@ -378,7 +270,7 @@ TEST(TestIteratorTransform, Abort) { ASSERT_OK(transformed.Next()); ASSERT_RAISES(Invalid, transformed.Next()); ASSERT_OK_AND_ASSIGN(auto third, transformed.Next()); - ASSERT_EQ(IterationTraits::End(), third); + ASSERT_TRUE(IsIterationEnd(third)); } template @@ -499,10 +391,6 @@ TEST(ReadaheadIterator, NotExhausted) { AssertIteratorNext({2}, it); } -void SleepABit(double seconds = 1e-3) { - std::this_thread::sleep_for(std::chrono::duration(seconds)); -} - TEST(ReadaheadIterator, Trace) { TracingIterator tracing_it(VectorIt({1, 2, 3, 4, 5, 6, 7, 8})); auto tracing = tracing_it.state(); @@ -573,513 +461,4 @@ TEST(ReadaheadIterator, NextError) { AssertIteratorExhausted(it); } -// -------------------------------------------------------------------- -// Asynchronous iterator tests - -TEST(TestAsyncUtil, Visit) { - auto generator = AsyncVectorIt({1, 2, 3}); - unsigned int sum = 0; - auto sum_future = VisitAsyncGenerator(generator, [&sum](TestInt item) { - sum += item.value; - return Status::OK(); - }); - ASSERT_TRUE(sum_future.is_finished()); - ASSERT_EQ(6, sum); -} - -TEST(TestAsyncUtil, Collect) { - std::vector expected = {1, 2, 3}; - auto generator = AsyncVectorIt(expected); - auto collected = CollectAsyncGenerator(generator); - ASSERT_FINISHES_OK_AND_ASSIGN(auto collected_val, collected); - ASSERT_EQ(expected, collected_val); -} - -TEST(TestAsyncUtil, SynchronousFinish) { - AsyncGenerator generator = []() { - return Future::MakeFinished(IterationTraits::End()); - }; - Transformer skip_all = [](TestInt value) { return TransformSkip(); }; - auto transformed = MakeAsyncGenerator(generator, skip_all); - auto future = CollectAsyncGenerator(transformed); - ASSERT_TRUE(future.is_finished()); - ASSERT_OK_AND_ASSIGN(auto actual, future.result()); - ASSERT_EQ(std::vector(), actual); -} - -TEST(TestAsyncUtil, GeneratorIterator) { - auto generator = BackgroundAsyncVectorIt({1, 2, 3}); - ASSERT_OK_AND_ASSIGN(auto iterator, MakeGeneratorIterator(std::move(generator))); - ASSERT_OK_AND_EQ(TestInt(1), iterator.Next()); - ASSERT_OK_AND_EQ(TestInt(2), iterator.Next()); - ASSERT_OK_AND_EQ(TestInt(3), iterator.Next()); - ASSERT_OK_AND_EQ(IterationTraits::End(), iterator.Next()); - ASSERT_OK_AND_EQ(IterationTraits::End(), iterator.Next()); -} - -TEST(TestAsyncUtil, MakeTransferredGenerator) { - std::mutex mutex; - std::condition_variable cv; - std::atomic finished(false); - - ASSERT_OK_AND_ASSIGN(auto thread_pool, internal::ThreadPool::Make(1)); - - // Needs to be a slow source to ensure we don't call Then on a completed - AsyncGenerator slow_generator = [&]() { - return thread_pool - ->Submit([&] { - std::unique_lock lock(mutex); - cv.wait_for(lock, std::chrono::duration(30), - [&] { return finished.load(); }); - return IterationTraits::End(); - }) - .ValueOrDie(); - }; - - auto transferred = - MakeTransferredGenerator(std::move(slow_generator), thread_pool.get()); - - auto current_thread_id = std::this_thread::get_id(); - auto fut = transferred().Then([¤t_thread_id](const Result& result) { - ASSERT_NE(current_thread_id, std::this_thread::get_id()); - }); - - { - std::lock_guard lg(mutex); - finished.store(true); - } - cv.notify_one(); - ASSERT_FINISHES_OK(fut); -} - -// This test is too slow for valgrind -#if !(defined(ARROW_VALGRIND) || defined(ADDRESS_SANITIZER)) - -TEST(TestAsyncUtil, StackOverflow) { - int counter = 0; - AsyncGenerator generator = [&counter]() { - if (counter < 1000000) { - return Future::MakeFinished(counter++); - } else { - return Future::MakeFinished(IterationTraits::End()); - } - }; - Transformer discard = - [](TestInt next) -> Result> { return TransformSkip(); }; - auto transformed = MakeAsyncGenerator(generator, discard); - auto collected_future = CollectAsyncGenerator(transformed); - ASSERT_FINISHES_OK_AND_ASSIGN(auto collected, collected_future); - ASSERT_EQ(0, collected.size()); -} - -#endif - -TEST(TestAsyncUtil, Background) { - std::vector expected = {1, 2, 3}; - auto background = BackgroundAsyncVectorIt(expected); - auto future = CollectAsyncGenerator(background); - ASSERT_FINISHES_OK_AND_ASSIGN(auto collected, future); - ASSERT_EQ(expected, collected); -} - -struct SlowEmptyIterator { - Result Next() { - if (called_) { - return Status::Invalid("Should not have been called twice"); - } - SleepFor(0.1); - return IterationTraits::End(); - } - - private: - bool called_ = false; -}; - -TEST(TestAsyncUtil, BackgroundRepeatEnd) { - // Ensure that the background generator properly fulfills the asyncgenerator contract - // and can be called after it ends. - ASSERT_OK_AND_ASSIGN(auto io_pool, internal::ThreadPool::Make(1)); - - auto iterator = Iterator(SlowEmptyIterator()); - ASSERT_OK_AND_ASSIGN(auto background_gen, - MakeBackgroundGenerator(std::move(iterator), io_pool.get())); - - background_gen = - MakeTransferredGenerator(std::move(background_gen), internal::GetCpuThreadPool()); - - auto one = background_gen(); - auto two = background_gen(); - - ASSERT_FINISHES_OK_AND_ASSIGN(auto one_fin, one); - ASSERT_EQ(IterationTraits::End(), one_fin); - - ASSERT_FINISHES_OK_AND_ASSIGN(auto two_fin, two); - ASSERT_EQ(IterationTraits::End(), two_fin); -} - -TEST(TestAsyncUtil, CompleteBackgroundStressTest) { - auto expected = RangeVector(20); - std::vector>> futures; - for (unsigned int i = 0; i < 20; i++) { - auto background = BackgroundAsyncVectorIt(expected); - futures.push_back(CollectAsyncGenerator(background)); - } - auto combined = All(futures); - ASSERT_FINISHES_OK_AND_ASSIGN(auto completed_vectors, combined); - for (std::size_t i = 0; i < completed_vectors.size(); i++) { - ASSERT_OK_AND_ASSIGN(auto vector, completed_vectors[i]); - ASSERT_EQ(vector, expected); - } -} - -template -class ReentrantChecker { - public: - explicit ReentrantChecker(AsyncGenerator source) - : state_(std::make_shared(std::move(source))) {} - - Future operator()() { - if (state_->in.load()) { - state_->valid.store(false); - } - state_->in.store(true); - auto result = state_->source(); - return result.Then(Callback{state_}); - } - - void AssertValid() { - EXPECT_EQ(true, state_->valid.load()) - << "The generator was accessed in a reentrant manner"; - } - - private: - struct State { - explicit State(AsyncGenerator source_) - : source(std::move(source_)), in(false), valid(true) {} - - AsyncGenerator source; - std::atomic in; - std::atomic valid; - }; - struct Callback { - Future operator()(const Result& result) { - state_->in.store(false); - return result; - } - std::shared_ptr state_; - }; - - std::shared_ptr state_; -}; - -TEST(TestAsyncUtil, SerialReadaheadSlowProducer) { - AsyncGenerator it = BackgroundAsyncVectorIt({1, 2, 3, 4, 5}); - ReentrantChecker checker(std::move(it)); - SerialReadaheadGenerator serial_readahead(checker, 2); - AssertAsyncGeneratorMatch({1, 2, 3, 4, 5}, - static_cast>(serial_readahead)); - checker.AssertValid(); -} - -TEST(TestAsyncUtil, SerialReadaheadSlowConsumer) { - int num_delivered = 0; - auto source = [&num_delivered]() { - if (num_delivered < 5) { - return Future::MakeFinished(num_delivered++); - } else { - return Future::MakeFinished(IterationTraits::End()); - } - }; - SerialReadaheadGenerator serial_readahead(std::move(source), 3); - SleepABit(); - ASSERT_EQ(0, num_delivered); - ASSERT_FINISHES_OK_AND_ASSIGN(auto next, serial_readahead()); - ASSERT_EQ(0, next.value); - ASSERT_EQ(3, num_delivered); - AssertAsyncGeneratorMatch({1, 2, 3, 4}, - static_cast>(serial_readahead)); -} - -TEST(TestAsyncUtil, SerialReadaheadStress) { - constexpr int NTASKS = 20; - constexpr int NITEMS = 50; - for (int i = 0; i < NTASKS; i++) { - AsyncGenerator it = BackgroundAsyncVectorIt(RangeVector(NITEMS)); - ReentrantChecker checker(std::move(it)); - SerialReadaheadGenerator serial_readahead(checker, 2); - auto visit_fut = - VisitAsyncGenerator(serial_readahead, [](TestInt test_int) -> Status { - // Normally sleeping in a visit function would be a faux-pas but we want to slow - // the reader down to match the producer to maximize the stress - std::this_thread::sleep_for(kYieldDuration); - return Status::OK(); - }); - ASSERT_FINISHES_OK(visit_fut); - checker.AssertValid(); - } -} - -TEST(TestAsyncUtil, SerialReadaheadStressFast) { - constexpr int NTASKS = 20; - constexpr int NITEMS = 50; - for (int i = 0; i < NTASKS; i++) { - AsyncGenerator it = BackgroundAsyncVectorIt(RangeVector(NITEMS), false); - ReentrantChecker checker(std::move(it)); - SerialReadaheadGenerator serial_readahead(checker, 2); - auto visit_fut = VisitAsyncGenerator( - serial_readahead, [](TestInt test_int) -> Status { return Status::OK(); }); - ASSERT_FINISHES_OK(visit_fut); - checker.AssertValid(); - } -} - -TEST(TestAsyncUtil, SerialReadaheadStressFailing) { - constexpr int NTASKS = 20; - constexpr int NITEMS = 50; - constexpr int EXPECTED_SUM = 45; - for (int i = 0; i < NTASKS; i++) { - AsyncGenerator it = BackgroundAsyncVectorIt(RangeVector(NITEMS)); - AsyncGenerator fails_at_ten = [&it]() { - auto next = it(); - return next.Then([](const Result& item) -> Result { - if (item->value >= 10) { - return Status::Invalid("XYZ"); - } else { - return item; - } - }); - }; - SerialReadaheadGenerator serial_readahead(fails_at_ten, 2); - unsigned int sum = 0; - auto visit_fut = VisitAsyncGenerator( - serial_readahead, [&sum](TestInt test_int) -> Status { - sum += test_int.value; - // Normally sleeping in a visit function would be a faux-pas but we want to slow - // the reader down to match the producer to maximize the stress - std::this_thread::sleep_for(kYieldDuration); - return Status::OK(); - }); - ASSERT_FINISHES_AND_RAISES(Invalid, visit_fut); - ASSERT_EQ(EXPECTED_SUM, sum); - } -} - -TEST(TestAsyncUtil, Readahead) { - int num_delivered = 0; - auto source = [&num_delivered]() { - if (num_delivered < 5) { - return Future::MakeFinished(num_delivered++); - } else { - return Future::MakeFinished(IterationTraits::End()); - } - }; - auto readahead = MakeReadaheadGenerator(source, 10); - // Should not pump until first item requested - ASSERT_EQ(0, num_delivered); - - auto first = readahead(); - // At this point the pumping should have happened - ASSERT_EQ(5, num_delivered); - ASSERT_FINISHES_OK_AND_ASSIGN(auto first_val, first); - ASSERT_EQ(TestInt(0), first_val); - - // Read the rest - for (int i = 0; i < 4; i++) { - auto next = readahead(); - ASSERT_FINISHES_OK_AND_ASSIGN(auto next_val, next); - ASSERT_EQ(TestInt(i + 1), next_val); - } - - // Next should be end - auto last = readahead(); - ASSERT_FINISHES_OK_AND_ASSIGN(auto last_val, last); - ASSERT_EQ(IterationTraits::End(), last_val); -} - -TEST(TestAsyncUtil, ReadaheadFailed) { - ASSERT_OK_AND_ASSIGN(auto thread_pool, internal::ThreadPool::Make(4)); - std::atomic counter(0); - // All tasks are a little slow. The first task fails. - // The readahead will have spawned 9 more tasks and they - // should all pass - auto source = [thread_pool, &counter]() -> Future { - auto count = counter++; - return *thread_pool->Submit([count]() -> Result { - if (count == 0) { - return Status::Invalid("X"); - } - return TestInt(count); - }); - }; - auto readahead = MakeReadaheadGenerator(source, 10); - ASSERT_FINISHES_AND_RAISES(Invalid, readahead()); - SleepABit(); - - for (int i = 0; i < 9; i++) { - ASSERT_FINISHES_OK_AND_ASSIGN(auto next_val, readahead()); - ASSERT_EQ(TestInt(i + 1), next_val); - } - ASSERT_FINISHES_OK_AND_ASSIGN(auto after, readahead()); - - // It's possible that finished was set quickly and there - // are only 10 elements - if (after == IterationTraits::End()) { - return; - } - - // It's also possible that finished was too slow and there - // ended up being 11 elements - ASSERT_EQ(TestInt(10), after); - // There can't be 12 elements because SleepABit will prevent it - ASSERT_FINISHES_OK_AND_ASSIGN(auto definitely_last, readahead()); - ASSERT_EQ(IterationTraits::End(), definitely_last); -} - -TEST(TestAsyncIteratorTransform, SkipSome) { - auto original = AsyncVectorIt({1, 2, 3}); - auto filter = MakeFilter([](TestInt& t) { return t.value != 2; }); - auto filtered = MakeAsyncGenerator(std::move(original), filter); - AssertAsyncGeneratorMatch({"1", "3"}, std::move(filtered)); -} - -TEST(PushGenerator, Empty) { - PushGenerator gen; - auto producer = gen.producer(); - - auto fut = gen(); - AssertNotFinished(fut); - producer.Close(); - ASSERT_FINISHES_OK_AND_EQ(IterationTraits::End(), fut); - ASSERT_FINISHES_OK_AND_EQ(IterationTraits::End(), gen()); - - // Close idempotent - fut = gen(); - producer.Close(); - ASSERT_FINISHES_OK_AND_EQ(IterationTraits::End(), fut); - ASSERT_FINISHES_OK_AND_EQ(IterationTraits::End(), gen()); - ASSERT_FINISHES_OK_AND_EQ(IterationTraits::End(), gen()); -} - -TEST(PushGenerator, Success) { - PushGenerator gen; - auto producer = gen.producer(); - std::vector> futures; - - producer.Push(TestInt{1}); - producer.Push(TestInt{2}); - for (int i = 0; i < 3; ++i) { - futures.push_back(gen()); - } - ASSERT_FINISHES_OK_AND_EQ(TestInt{1}, futures[0]); - ASSERT_FINISHES_OK_AND_EQ(TestInt{2}, futures[1]); - AssertNotFinished(futures[2]); - - producer.Push(TestInt{3}); - ASSERT_FINISHES_OK_AND_EQ(TestInt{3}, futures[2]); - producer.Push(TestInt{4}); - futures.push_back(gen()); - ASSERT_FINISHES_OK_AND_EQ(TestInt{4}, futures[3]); - producer.Push(TestInt{5}); - producer.Close(); - for (int i = 0; i < 4; ++i) { - futures.push_back(gen()); - } - ASSERT_FINISHES_OK_AND_EQ(TestInt{5}, futures[4]); - for (int i = 5; i < 8; ++i) { - ASSERT_FINISHES_OK_AND_EQ(IterationTraits::End(), futures[i]); - } - ASSERT_FINISHES_OK_AND_EQ(IterationTraits::End(), gen()); -} - -TEST(PushGenerator, Errors) { - PushGenerator gen; - auto producer = gen.producer(); - std::vector> futures; - - producer.Push(TestInt{1}); - producer.Push(Status::Invalid("2")); - for (int i = 0; i < 3; ++i) { - futures.push_back(gen()); - } - ASSERT_FINISHES_OK_AND_EQ(TestInt{1}, futures[0]); - ASSERT_FINISHES_AND_RAISES(Invalid, futures[1]); - AssertNotFinished(futures[2]); - - producer.Push(Status::IOError("3")); - producer.Push(TestInt{4}); - ASSERT_FINISHES_AND_RAISES(IOError, futures[2]); - futures.push_back(gen()); - ASSERT_FINISHES_OK_AND_EQ(TestInt{4}, futures[3]); - producer.Close(); - ASSERT_FINISHES_OK_AND_EQ(IterationTraits::End(), gen()); -} - -TEST(PushGenerator, CloseEarly) { - PushGenerator gen; - auto producer = gen.producer(); - std::vector> futures; - - producer.Push(TestInt{1}); - producer.Push(TestInt{2}); - for (int i = 0; i < 3; ++i) { - futures.push_back(gen()); - } - producer.Close(); - producer.Push(TestInt{3}); - - ASSERT_FINISHES_OK_AND_EQ(TestInt{1}, futures[0]); - ASSERT_FINISHES_OK_AND_EQ(TestInt{2}, futures[1]); - ASSERT_FINISHES_OK_AND_EQ(IterationTraits::End(), futures[2]); - ASSERT_FINISHES_OK_AND_EQ(IterationTraits::End(), gen()); -} - -TEST(PushGenerator, Stress) { - const int NTHREADS = 20; - const int NVALUES = 2000; - const int NFUTURES = NVALUES + 100; - - PushGenerator gen; - auto producer = gen.producer(); - - std::atomic next_value{0}; - - auto producer_worker = [&]() { - while (true) { - int v = next_value.fetch_add(1); - if (v >= NVALUES) { - break; - } - producer.Push(v); - } - }; - - auto producer_main = [&]() { - std::vector threads; - for (int i = 0; i < NTHREADS; ++i) { - threads.emplace_back(producer_worker); - } - for (auto& thread : threads) { - thread.join(); - } - producer.Close(); - }; - - std::vector> results; - std::thread thread(producer_main); - for (int i = 0; i < NFUTURES; ++i) { - results.push_back(gen().result()); - } - thread.join(); - - std::unordered_set seen_values; - for (int i = 0; i < NVALUES; ++i) { - ASSERT_OK_AND_ASSIGN(auto v, results[i]); - ASSERT_EQ(seen_values.count(v.value), 0); - seen_values.insert(v.value); - } - for (int i = NVALUES; i < NFUTURES; ++i) { - ASSERT_OK_AND_EQ(IterationTraits::End(), results[i]); - } -} - } // namespace arrow diff --git a/cpp/src/arrow/util/stl_util_test.cc b/cpp/src/arrow/util/stl_util_test.cc index 4746c6f3700..2a8784e13a8 100644 --- a/cpp/src/arrow/util/stl_util_test.cc +++ b/cpp/src/arrow/util/stl_util_test.cc @@ -21,6 +21,7 @@ #include #include +#include "arrow/testing/gtest_util.h" #include "arrow/util/sort.h" #include "arrow/util/string.h" #include "arrow/util/vector.h" @@ -92,5 +93,80 @@ TEST(StlUtilTest, ArgSortPermute) { ExpectSortPermutation({b, c, d, e, a, f}, {4, 0, 1, 2, 3, 5}, 2); } +TEST(StlUtilTest, VectorFlatten) { + std::vector a{1, 2, 3}; + std::vector b{4, 5, 6}; + std::vector c{7, 8, 9}; + std::vector> vecs{a, b, c}; + auto actual = FlattenVectors(vecs); + std::vector expected{1, 2, 3, 4, 5, 6, 7, 8, 9}; + ASSERT_EQ(expected, actual); +} + +static std::string int_to_str(int val) { return std::to_string(val); } + +TEST(StlUtilTest, VectorMap) { + std::vector input{1, 2, 3}; + std::vector expected{"1", "2", "3"}; + + auto actual = MapVector(int_to_str, input); + ASSERT_EQ(expected, actual); + + auto bind_fn = std::bind(int_to_str, std::placeholders::_1); + actual = MapVector(bind_fn, input); + ASSERT_EQ(expected, actual); + + std::function std_fn = int_to_str; + actual = MapVector(std_fn, input); + ASSERT_EQ(expected, actual); + + actual = MapVector([](int val) { return std::to_string(val); }, input); + ASSERT_EQ(expected, actual); +} + +TEST(StlUtilTest, VectorMaybeMapFails) { + std::vector input{1, 2, 3}; + auto mapper = [](int item) -> Result { + if (item == 1) { + return Status::Invalid("XYZ"); + } + return std::to_string(item); + }; + ASSERT_RAISES(Invalid, MaybeMapVector(mapper, input)); +} + +TEST(StlUtilTest, VectorMaybeMap) { + std::vector input{1, 2, 3}; + std::vector expected{"1", "2", "3"}; + EXPECT_OK_AND_ASSIGN( + auto actual, + MaybeMapVector([](int item) -> Result { return std::to_string(item); }, + input)); + ASSERT_EQ(expected, actual); +} + +TEST(StlUtilTest, VectorUnwrapOrRaise) { + // TODO(ARROW-11998) There should be an easier way to construct these vectors + std::vector> all_good; + all_good.push_back(Result(MoveOnlyDataType(1))); + all_good.push_back(Result(MoveOnlyDataType(2))); + all_good.push_back(Result(MoveOnlyDataType(3))); + + std::vector> some_bad; + some_bad.push_back(Result(MoveOnlyDataType(1))); + some_bad.push_back(Result(Status::Invalid("XYZ"))); + some_bad.push_back(Result(Status::IOError("XYZ"))); + + EXPECT_OK_AND_ASSIGN(auto unwrapped, UnwrapOrRaise(std::move(all_good))); + std::vector expected; + expected.emplace_back(1); + expected.emplace_back(2); + expected.emplace_back(3); + + ASSERT_EQ(expected, unwrapped); + + ASSERT_RAISES(Invalid, UnwrapOrRaise(std::move(some_bad))); +} + } // namespace internal } // namespace arrow diff --git a/cpp/src/arrow/util/task_group.h b/cpp/src/arrow/util/task_group.h index 7a96bada013..3bb72f0d9cb 100644 --- a/cpp/src/arrow/util/task_group.h +++ b/cpp/src/arrow/util/task_group.h @@ -21,6 +21,7 @@ #include #include "arrow/status.h" +#include "arrow/type_fwd.h" #include "arrow/util/cancel.h" #include "arrow/util/functional.h" #include "arrow/util/macros.h" diff --git a/cpp/src/arrow/util/test_common.cc b/cpp/src/arrow/util/test_common.cc new file mode 100644 index 00000000000..0aaa02d5c2b --- /dev/null +++ b/cpp/src/arrow/util/test_common.cc @@ -0,0 +1,67 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/util/test_common.h" + +namespace arrow { + +TestInt::TestInt() : value(-999) {} +TestInt::TestInt(int i) : value(i) {} // NOLINT runtime/explicit +bool TestInt::operator==(const TestInt& other) const { return value == other.value; } + +std::ostream& operator<<(std::ostream& os, const TestInt& v) { + os << "{" << v.value << "}"; + return os; +} + +TestStr::TestStr() : value("") {} +TestStr::TestStr(const std::string& s) : value(s) {} // NOLINT runtime/explicit +TestStr::TestStr(const char* s) : value(s) {} // NOLINT runtime/explicit +TestStr::TestStr(const TestInt& test_int) { + if (IsIterationEnd(test_int)) { + value = ""; + } else { + value = std::to_string(test_int.value); + } +} + +bool TestStr::operator==(const TestStr& other) const { return value == other.value; } + +std::ostream& operator<<(std::ostream& os, const TestStr& v) { + os << "{\"" << v.value << "\"}"; + return os; +} + +std::vector RangeVector(unsigned int max) { + std::vector range(max); + for (unsigned int i = 0; i < max; i++) { + range[i] = i; + } + return range; +} + +Transformer MakeFilter(std::function filter) { + return [filter](TestInt next) -> Result> { + if (filter(next)) { + return TransformYield(TestStr(next)); + } else { + return TransformSkip(); + } + }; +} + +} // namespace arrow diff --git a/cpp/src/arrow/util/test_common.h b/cpp/src/arrow/util/test_common.h new file mode 100644 index 00000000000..e3162004b28 --- /dev/null +++ b/cpp/src/arrow/util/test_common.h @@ -0,0 +1,74 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include + +#include "arrow/testing/gtest_util.h" +#include "arrow/util/iterator.h" + +namespace arrow { + +struct TestInt { + TestInt(); + TestInt(int i); // NOLINT runtime/explicit + int value; + + bool operator==(const TestInt& other) const; + + friend std::ostream& operator<<(std::ostream& os, const TestInt& v); +}; + +template <> +struct IterationTraits { + static TestInt End() { return TestInt(); } + static bool IsEnd(const TestInt& val) { return val == IterationTraits::End(); } +}; + +struct TestStr { + TestStr(); + TestStr(const std::string& s); // NOLINT runtime/explicit + TestStr(const char* s); // NOLINT runtime/explicit + explicit TestStr(const TestInt& test_int); + std::string value; + + bool operator==(const TestStr& other) const; + + friend std::ostream& operator<<(std::ostream& os, const TestStr& v); +}; + +template <> +struct IterationTraits { + static TestStr End() { return TestStr(); } + static bool IsEnd(const TestStr& val) { return val == IterationTraits::End(); } +}; + +std::vector RangeVector(unsigned int max); + +template +inline Iterator VectorIt(std::vector v) { + return MakeVectorIterator(std::move(v)); +} + +template +inline void AssertIteratorExhausted(Iterator& it) { + ASSERT_OK_AND_ASSIGN(T next, it.Next()); + ASSERT_TRUE(IsIterationEnd(next)); +} + +Transformer MakeFilter(std::function filter); + +} // namespace arrow diff --git a/cpp/src/arrow/util/type_fwd.h b/cpp/src/arrow/util/type_fwd.h index d29b130ebbd..f5d01518862 100644 --- a/cpp/src/arrow/util/type_fwd.h +++ b/cpp/src/arrow/util/type_fwd.h @@ -23,8 +23,6 @@ namespace detail { struct Empty; } // namespace detail -template -class Future; template class WeakFuture; class FutureWaiter; diff --git a/cpp/src/arrow/util/vector.h b/cpp/src/arrow/util/vector.h index cbd874dacae..67401d496e6 100644 --- a/cpp/src/arrow/util/vector.h +++ b/cpp/src/arrow/util/vector.h @@ -21,6 +21,9 @@ #include #include +#include "arrow/result.h" +#include "arrow/util/algorithm.h" +#include "arrow/util/functional.h" #include "arrow/util/logging.h" namespace arrow { @@ -81,5 +84,54 @@ std::vector FilterVector(std::vector values, Predicate&& predicate) { return values; } +/// \brief Like MapVector, but where the function can fail. +template , + typename To = typename internal::call_traits::return_type::ValueType> +Result> MaybeMapVector(Fn&& map, const std::vector& src) { + std::vector out; + out.reserve(src.size()); + ARROW_RETURN_NOT_OK(MaybeTransform(src.begin(), src.end(), std::back_inserter(out), + std::forward(map))); + return out; +} + +template ()(std::declval()))> +std::vector MapVector(Fn&& map, const std::vector& source) { + std::vector out; + out.reserve(source.size()); + std::transform(source.begin(), source.end(), std::back_inserter(out), + std::forward(map)); + return out; +} + +template +std::vector FlattenVectors(const std::vector>& vecs) { + std::size_t sum = 0; + for (const auto& vec : vecs) { + sum += vec.size(); + } + std::vector out; + out.reserve(sum); + for (const auto& vec : vecs) { + out.insert(out.end(), vec.begin(), vec.end()); + } + return out; +} + +template +Result> UnwrapOrRaise(std::vector>&& results) { + std::vector out; + out.reserve(results.size()); + auto end = std::make_move_iterator(results.end()); + for (auto it = std::make_move_iterator(results.begin()); it != end; it++) { + if (!it->ok()) { + return it->status(); + } + out.push_back(it->MoveValueUnsafe()); + } + return out; +} + } // namespace internal } // namespace arrow