diff --git a/llvm/include/llvm/ADT/STLExtras.h b/llvm/include/llvm/ADT/STLExtras.h index a9841c6651b72..1cbc1d051aa18 100644 --- a/llvm/include/llvm/ADT/STLExtras.h +++ b/llvm/include/llvm/ADT/STLExtras.h @@ -325,6 +325,20 @@ template auto drop_end(T &&RangeOrContainer, size_t N = 1) { std::prev(adl_end(RangeOrContainer), N)); } +/// Return a range covering \p RangeOrContainer with the first N elements +/// included. +template auto take_begin(T &&RangeOrContainer, size_t N = 1) { + return make_range(adl_begin(RangeOrContainer), + std::next(adl_begin(RangeOrContainer), N)); +} + +/// Return a range covering \p RangeOrContainer with the last N elements +/// included. +template auto take_end(T &&RangeOrContainer, size_t N = 1) { + return make_range(std::prev(adl_end(RangeOrContainer), N), + adl_end(RangeOrContainer)); +} + // mapped_iterator - This is a simple iterator adapter that causes a function to // be applied whenever operator* is invoked on the iterator. diff --git a/llvm/unittests/ADT/STLExtrasTest.cpp b/llvm/unittests/ADT/STLExtrasTest.cpp index 966b1f01e8a31..ad71f8068366c 100644 --- a/llvm/unittests/ADT/STLExtrasTest.cpp +++ b/llvm/unittests/ADT/STLExtrasTest.cpp @@ -800,6 +800,35 @@ TEST(STLExtrasTest, DropEndDefaultTest) { EXPECT_THAT(drop_end(vec), ElementsAre(0, 1, 2, 3)); } +TEST(STLExtrasTest, TakeBeginTest) { + SmallVector vec{0, 1, 2, 3, 4}; + + for (int n = 0; n < 5; ++n) { + EXPECT_THAT(take_begin(vec, n), ElementsAreArray(ArrayRef(vec.data(), n))); + } +} + +TEST(STLExtrasTest, TakeBeginDefaultTest) { + SmallVector vec{0, 1, 2, 3, 4}; + + EXPECT_THAT(take_begin(vec), ElementsAre(0)); +} + +TEST(STLExtrasTest, TakeEndTest) { + SmallVector vec{0, 1, 2, 3, 4}; + + for (int n = 0; n < 5; ++n) { + EXPECT_THAT(take_end(vec, n), + ElementsAreArray(ArrayRef(&vec[vec.size() - n], n))); + } +} + +TEST(STLExtrasTest, TakeEndDefaultTest) { + SmallVector vec{0, 1, 2, 3, 4}; + + EXPECT_THAT(take_end(vec), ElementsAre(4)); +} + TEST(STLExtrasTest, MapRangeTest) { SmallVector Vec{0, 1, 2}; EXPECT_THAT(map_range(Vec, [](int V) { return V + 1; }),