diff --git a/velox/functions/prestosql/aggregates/ApproxPercentileAggregate.cpp b/velox/functions/prestosql/aggregates/ApproxPercentileAggregate.cpp index e0338a5a636..13f2eddeb9d 100644 --- a/velox/functions/prestosql/aggregates/ApproxPercentileAggregate.cpp +++ b/velox/functions/prestosql/aggregates/ApproxPercentileAggregate.cpp @@ -213,6 +213,7 @@ class ApproxPercentileAggregate : public exec::Aggregate { vector_size_t index) { digest.estimateQuantiles(percentiles, rawValues + elementsCount); result->setOffsetAndSize(index, elementsCount, percentiles.size()); + result->setNull(index, false); elementsCount += percentiles.size(); }); } else { diff --git a/velox/functions/prestosql/aggregates/tests/ApproxPercentileTest.cpp b/velox/functions/prestosql/aggregates/tests/ApproxPercentileTest.cpp index 18c458a05ea..387492db533 100644 --- a/velox/functions/prestosql/aggregates/tests/ApproxPercentileTest.cpp +++ b/velox/functions/prestosql/aggregates/tests/ApproxPercentileTest.cpp @@ -20,10 +20,12 @@ #include "velox/exec/tests/utils/AssertQueryBuilder.h" #include "velox/exec/tests/utils/PlanBuilder.h" #include "velox/functions/lib/aggregates/tests/utils/AggregationTestBase.h" +#include "velox/functions/lib/window/tests/WindowTestBase.h" using namespace facebook::velox::exec; using namespace facebook::velox::exec::test; using namespace facebook::velox::functions::aggregate::test; +using namespace facebook::velox::window::test; namespace facebook::velox::aggregate::test { @@ -565,5 +567,32 @@ TEST_F(ApproxPercentileTest, nullPercentile) { "Percentile cannot be null"); } +class ApproxPercentileWindowTest : public WindowTestBase { + protected: + void SetUp() override { + WindowTestBase::SetUp(); + random::setSeed(0); + } +}; + +TEST_F(ApproxPercentileWindowTest, window) { + auto data = makeRowVector( + {makeFlatVector({1, 2, 3}), + makeNullableFlatVector({10, std::nullopt, 30}), + makeArrayVectorFromJson({"[0.5]", "[0.5]", "[0.5]"})}); + auto expected = makeRowVector({ + makeFlatVector({1, 2, 3}), + makeNullableFlatVector({10, std::nullopt, 30}), + makeArrayVectorFromJson({"[0.5]", "[0.5]", "[0.5]"}), + makeNullableArrayVector({{{10}}, std::nullopt, {{30}}}), + }); + testWindowFunction( + {data}, + "approx_percentile(c1, c2)", + "order by c0", + "rows between current row and current row", + expected); +} + } // namespace } // namespace facebook::velox::aggregate::test